关于predict 和 argmax

来源:2-8 神经网络实现(多分类逻辑斯蒂回归模型实现)

Elijahol0808

2019-06-12

老师,请问个问题!predict = tf.argmax(y_,1) 返回的是在在这十个类里最大值的索引,但是下面的判断正确 correct_prediction = tf.equal(predict,y) 这句就有点看不懂了。y是样本的分类,也就是一个一维长度为data.shape[0]的列表,每一个数对应着相对的样本的label。也就是[3,4,2,5,6,…]. 不知道理解的对不对?
用predict和y判断是否对应元素相等,一个是最大值的索引,一个是真正的label,我不是太理解这怎么判断相等?

图片描述

这么修改是否是正确的?
用经过softmax函数后得到p_y,求p_y的最大值的索引,并与y_one_hot的最大值索引作比较

写回答

1回答

正十七

2019-06-13

同学你好,在取了索引之后,predict就变成了一个一维向量了。比如我有两个样本,预测出的p_y为

[[0.1, 0.3, 0.4, 0.2], [0.5, 0.2, 0.2, 0.1]], 那么经过argmax之后就变为了[2, 0], 而这时候如果你的真实值就是一维数组(eg: [1, 0])的话,那么就可以直接tf.equal(), 如果你的真实值是one_hot编码的话,那么就应该用你的修改,但这时候,我觉得考虑不做one_hot会更直接。

1
1
Elijahol0808
非常感谢!谢谢老师
2019-06-14
共1条回复

深度学习之神经网络(CNN/RNN/GAN)算法原理+实战

深度学习算法工程师必学,深入理解深度学习核心算法CNN RNN GAN

2617 学习 · 935 问题

查看课程