train函数的steps参数怎么设定?

来源:4-5 tf.data读取csv文件并与tf.keras结合使用

蚂蚁帅帅

2019-07-18

你好老师,我查了下tensorflow的官网的API:tf.estimator.LinearClassifier

它的train方法是这样的:
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)

我的疑问是这里的steps的值怎么确定啊?

看视频里的讲解,是不是我用这个公式计算比较好啊?
train_x.shape[0] // batch_size

其中train_x是训练输入,shape[0]我就取到了训练数据的行数;//batchsize就得到了数据被拆分的份数;

写回答

1回答

正十七

2019-07-21

train_x.shape[0] // batch_size 我们在keras的train函数中常常用到,它的作用是遍历数据集一次所需要的步数,然后keras.train函数里还有一个epoch参数来决定遍历数据集几次。

在estimator里,steps就是所有的步数,所以应该是总步数 train_x.shape[0] // batch_size * epochs 

0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程