什么样子的数据集需要repeat()呢

来源:4-5 tf.data读取csv文件并与tf.keras结合使用

weixin_慕勒7004644

2021-01-05

在之前用的california_housing dataset直接导入时,没有用到repeat(),为什么转成csv文件再次导入时需要repeat()呢?感谢老师!

def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset
写回答

1回答

正十七

2021-01-07

我在后面的逻辑里使用的具体的样本数字来控制的输出,所以这里的repeat可加可不加,只要后面保证dataset只被遍历一遍就可以。

1
1
weixin_慕勒7004644
非常感谢!
2021-01-07
共1条回复

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程