pass

来源:4-7 生成tfrecords文件

战战的坚果

2020-02-21

for x_batch, y_batch in dataset.take(steps_per_shard):
for x_example, y_example in zip(x_batch, y_batch):#解batch
老师,此时的数据集是之前的8个属性与类标签结合起来,相当于一行9个属性的形式,现在从数据集中取出前steps_per_shard个batch,我想请教的是:获取到steps_per_shard个batch后,是如何赋值给 x_batch, y_batch 的,又是怎样实现对batch解绑的,赋值给 x_example, y_example的,解绑是指将9个属性拆分成8+1吗?

写回答

1回答

正十七

2020-02-24

如下面的代码所示:

def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y
 
def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

在csv_reader_dataset中,我们在map函数中对每一行进行解析。

在parse_csv_line中,我们把每一行给拆成了前八个和后一个,即x和y。这样csv_reader_dataset返回的dataset里,每一个batch都是两个元素,即x_batch和y_batch。

而batch解绑定直接用了for循环,如下,并不是9拆8+1。

with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
    for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):                for x_example, y_example in zip(x_batch, y_batch):
        writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)


0
1
战战的坚果
老师,以下是我的理解,您看对吗? 在csv_reader_dataset中,在map函数中对每一行进行解析。 在parse_csv_line中,把每一行给拆成了前八个和后一个,即x和y。这样csv_reader_dataset返回的dataset里,每一个batch都是两个元素,即x_batch:即batchsize条数据的前8个元素和y_batch:即batchsize条数据的后1个元素,则 for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard): 执行后,取出了steps_per_shard个batch,然后通过for循环取出一个batch,并将这一个batch(由两部分组成)赋值给x_batch, y_batch,此时x_batch:是batchsize条数据的前8个元素,y_batch是batchsize条数据的后1个元素,此时通过第二个for循环,将x_batch, y_batch连接起来,依次取出batchsize条数据的第一条、第二条。。。第batchsize条数据,将第一条数据的前8个元素赋值给x_example, 后一个元素赋值y_example,所以serialize_example(x_example, y_example)是将一行数据序列化。
2020-02-28
共1条回复

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程

相似问题

pass

回答 1

pass

回答 1

pass

回答 1

pass

回答 1

pass

回答 1