分布式中的损失函数

来源:8-9 分布式自定义流程实战

紫梦沁香

2021-09-02

这里得到了每一个replica的平均损失,最后聚合为什么不是加起来再求平均呢?
图片描述

写回答

1回答

正十七

2021-09-04

因为底层的梯度在不同的replica上是加起来的,所以如果在每个replica上不除以global_batch_size, 会导致计算得到的梯度不正常。这里的先算均值再加的操作是follow梯度的计算方式。

参考https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function



0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程