Source: 4-6 Basic tfrecord API usage
战战的坚果
2020-02-18
def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    # repeat() with no argument repeats the data indefinitely; the epoch
    # length is supplied later via steps_per_epoch / validation_steps / steps.
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        # skip(1) drops the header line of each CSV file
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers
    )
    # shuffle() returns a new dataset, so the result must be reassigned
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset
train_set = csv_reader_dataset(train_filenames,
                               batch_size=batch_size)
valid_set = csv_reader_dataset(valid_filenames,
                               batch_size=batch_size)
test_set = csv_reader_dataset(test_filenames,
                              batch_size=batch_size)
Teacher, in this code ***dataset = dataset.repeat()*** repeats the dataset an unlimited number of times, and this function is then applied to the training, validation, and test sets alike, so for each of the training, validation, and test sets we have to specify how many steps to run per epoch, right? So:
history = model.fit(train_set,
                    validation_data=valid_set,
                    steps_per_epoch=11160 // batch_size,
                    validation_steps=3870 // batch_size,
                    epochs=100,
                    callbacks=callbacks)
#%%
model.evaluate(test_set, steps=5160 // batch_size)
That is why the code above contains the three statements shown:
steps_per_epoch=11160 // batch_size,
validation_steps=3870 // batch_size,
steps=5160 // batch_size.
Teacher, is my understanding correct?
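(A minimal sketch of the behavior described above, assuming the train_set built earlier and that parse_csv_line yields (x, y) pairs as elsewhere in the course: because of repeat(), the dataset produces batches without end, so any loop over it must be bounded explicitly, e.g. with take().)

# Peek at two batches; without take() this loop would never terminate,
# since repeat() makes the dataset unbounded.
for x_batch, y_batch in train_set.take(2):
    print(x_batch.shape, y_batch.shape)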
1 Answer
正十七
2020-02-24
Yes, your understanding is correct. Because repeat() is called, the training, validation, and test sets all become effectively infinite, so to let the fit function know how many steps make up one epoch I used a trick: I checked in advance how many samples the training, validation, and test sets each contain, and then computed the corresponding step count for each of them.
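(A minimal sketch of that trick, assuming the train_filenames / valid_filenames / test_filenames lists used above; count_csv_rows is an illustrative helper, not part of the course code, that counts the data rows of each split once, outside the repeated pipeline.)

def count_csv_rows(filenames):
    # Count data rows across a split's CSV files, skipping each file's header line.
    total = 0
    for filename in filenames:
        total += sum(1 for _ in tf.data.TextLineDataset(filename).skip(1))
    return total

train_size = count_csv_rows(train_filenames)  # expected 11160 for the split above
valid_size = count_csv_rows(valid_filenames)  # expected 3870
test_size = count_csv_rows(test_filenames)    # expected 5160

steps_per_epoch = train_size // batch_size
validation_steps = valid_size // batch_size
eval_steps = test_size // batch_size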