_, pred = output.max(1)
来源:4-7 利用神经网络解决分类和回归问题(5)

谢思阳
2020-12-06
这里没看懂啊,这里是什么意思呢?
写回答
1回答
-
首先,这是一个分类任务,也就是分成数字0-9十类, 那么对于一张图片其输出也就有十种可能。分类网络的任务就是计算这张图是十类中的每一类的概率,然后再选出最大概率的索引值表示类别。比如一张图片的经过网络的输出结果是:[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0], 索引值为3的时候概率最大,为0.7, 那么网络就会预测这张图是3
至 _,pred = output.max(1) , 就是取最大索引的操作。1实际上是表示取第一维度上的最大索引, 为什么还要有个第一维度呢?是因为我们是按batch_size进行计算的,假设batch_size = 3(也就是对三张图片分类,那就有三个类似[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0]的结果):
[
[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0],
[0.1, 0.9, 0, 0, 0, 0, 0, 0, 0, 0],
[1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
_,pred = output.max(1) 的结果是 [3, 1, 0],其意思就是对每一张图片的预测结果求最大索引,给定维度1就是对这个batch_size里面的每一个[1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0]求最大索引
222021-01-24
相似问题