eval_with_sess函数应该不用x,y参数吧?
来源:5-9 TF1.0模型训练
站在你背后的
2020-03-04
eval_with_sess函数应该不用x,y参数吧?去掉后运行结果一致
def eval_with_sess(sess, accuracy, images, labels, batch_size):
eval_steps = images.shape[0] // batch_size
eval_accuracies = []
for step in range(eval_steps):
batch_data = images[step * batch_size : (step+1) * batch_size]
batch_label = labels[step * batch_size : (step+1) * batch_size]
accuracy_val = sess.run(accuracy,
feed_dict = {
x: batch_data,
y: batch_label
})
eval_accuracies.append(accuracy_val)
return np.mean(eval_accuracies)
#session:客户端(python) 与C++运行时(tensorflow图结构)中间的连接
with tf.Session() as sess:
sess.run(init)
for epoch in range(epochs):
for step in range(train_steps_per_epoch):
batch_data = x_train_scaled[
step * batch_size : (step+1) * batch_size]
batch_label = y_train[
step * batch_size : (step+1) * batch_size]
# 一次可以run多个参数列表,通过feed_dict喂进数据
loss_val, accuracy_val, _ = sess.run(
[loss, accuracy, train_op],
feed_dict = {
x: batch_data,
y: batch_label
})
print(’\r[Train] epoch: %d, step: %d, loss: %3.5f, accuracy: %2.2f’ % (
epoch, step, loss_val, accuracy_val), end="")
valid_accuracy = eval_with_sess(sess, accuracy,
x_valid_scaled, y_valid,
batch_size)
print("\t[Valid] acc: %2.2f" % (valid_accuracy))
1回答
-
正十七
2020-03-05
需要传的,你这里不需要是因为eval_with_sess跟x, y在一个文件中,如果在不同的文件中应该会出错。
00
相似问题