打印损失值的回调函数
来源:2-6 实战回调函数
OCEANDREAM
2021-01-13
自定义了一个模型,继承了keras.Model,想利用.fit() 函数进行训练,所以重写了模型中的train_step函数,在train_step函数最后,以字典的形式,return了各个loss, 然后有写了一个回调函数def on_train_batch_end(self, batch, logs=None), 在每个batch结束的时候,用logs[‘loss’] 的形式打印了损失,但是为什么和调用.fit() 函数每个batch的损失值不同呢?损失函数是自定义的
写回答
1回答
-
正十七
2021-01-22
可能的原因是:
tf.keras.Model的fit函数里计算的是累积的metrics,即一个epoch上遍历时metrics在开始reset,然后过程中是累积的。而你的实现可能是单步的?
只是猜测,具体还需要看源码。
00
相似问题