打印损失值的回调函数

来源:2-6 实战回调函数

OCEANDREAM

2021-01-13

自定义了一个模型,继承了keras.Model,想利用.fit() 函数进行训练,所以重写了模型中的train_step函数,在train_step函数最后,以字典的形式,return了各个loss, 然后有写了一个回调函数def on_train_batch_end(self, batch, logs=None), 在每个batch结束的时候,用logs[‘loss’] 的形式打印了损失,但是为什么和调用.fit() 函数每个batch的损失值不同呢?损失函数是自定义的

写回答

1回答

正十七

2021-01-22

可能的原因是:

tf.keras.Model的fit函数里计算的是累积的metrics,即一个epoch上遍历时metrics在开始reset,然后过程中是累积的。而你的实现可能是单步的?

只是猜测,具体还需要看源码。

0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程