tf.GradientTape()的理解
来源:3-12 tf.GradientTape与tf.keras结合使用
wxz123
2019-11-02
老师对于tf.GradientTape()这个函数我还是有点懵,它的主要作用是让loss可以对model.variables求梯度值吗?那为什么还要让
y_pred = model(x_batch)
y_pred = tf.squeeze(y_pred, 1)
metric(y_batch, y_pred)
这三句也要在with tf.GradientTape() as tape的作用域当中呢
写回答
1回答
-
正十七
2019-11-17
同学你好,让这三个操作放在tf.GradientTape的作用域当中是因为GradientTape需要监控要做梯度的操作。不然就会得到报错:
No gradients provided for any variable: ['dense_2/kernel:0', 'dense_2/bias:0', 'dense_3/kernel:0', 'dense_3/bias:0']
https://www.tensorflow.org/api_docs/python/tf/GradientTape
00
相似问题