pass

来源:3-12 tf.GradientTape与tf.keras结合使用

战战的坚果

2020-04-29

老师,我看到的本小节的代码和您课上讲的有些不同,您能解释一下代码的修改部分吗?
1、 loss = keras.losses.mean_squared_error(y_batch, y_pred)中去掉了tf.reduce_mean?
2、 y_valid_pred = tf.squeeze(y_valid_pred, 1),多加了这句,是什么意思呢?

for epoch in range(epochs):
    metric.reset_states()
    for step in range(steps_per_epoch):
        x_batch, y_batch = random_batch(x_train_scaled, y_train,
                                        batch_size)
        with tf.GradientTape() as tape:
            y_pred = model(x_batch)
            y_pred = tf.squeeze(y_pred, 1)
            loss = keras.losses.mean_squared_error(y_batch, y_pred)
            metric(y_batch, y_pred)
        grads = tape.gradient(loss, model.variables)
        grads_and_vars = zip(grads, model.variables)
        optimizer.apply_gradients(grads_and_vars)
        print("
Epoch", epoch, " train mse:",
              metric.result().numpy(), end="")
    y_valid_pred = model(x_valid_scaled)
    y_valid_pred = tf.squeeze(y_valid_pred, 1)
    valid_loss = keras.losses.mean_squared_error(y_valid_pred, y_valid)
    print("	", "valid mse: ", valid_loss.numpy())
写回答

1回答

正十七

2020-05-08

  1. 去掉了tf.reduce_mean是因为多余,因为mean_square_error已经算了平均。

  2. tf.squeeze是删减了值为1的维度,这是因为model(x_valid_scaled)得到的结果是[num_examples, 1],跟y_pred的维度(num_examples, ),多一个维度。

0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程

相似问题

pass

回答 1

pass

回答 1

pass

回答 1

pass

回答 1

pass

回答 1