分布式中的损失函数
来源:8-9 分布式自定义流程实战

紫梦沁香
2021-09-02
这里得到了每一个replica的平均损失,最后聚合为什么不是加起来再求平均呢?
写回答
1回答
-
正十七
2021-09-04
因为底层的梯度在不同的replica上是加起来的,所以如果在每个replica上不除以global_batch_size, 会导致计算得到的梯度不正常。这里的先算均值再加的操作是follow梯度的计算方式。
参考https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
00
相似问题