_, pred = output.max(1)

来源:4-7 利用神经网络解决分类和回归问题(5)

谢思阳

2020-12-06

这里没看懂啊,这里是什么意思呢?

写回答

1回答

荼灬

2020-12-09

首先,这是一个分类任务,也就是分成数字0-9十类, 那么对于一张图片其输出也就有十种可能。分类网络的任务就是计算这张图是十类中的每一类的概率,然后再选出最大概率的索引值表示类别。比如一张图片的经过网络的输出结果是:[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0],  索引值为3的时候概率最大,为0.7, 那么网络就会预测这张图是3

 至 _,pred = output.max(1) ,  就是取最大索引的操作。1实际上是表示取第一维度上的最大索引, 为什么还要有个第一维度呢?是因为我们是按batch_size进行计算的,假设batch_size = 3(也就是对三张图片分类,那就有三个类似[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0]的结果):

[

[0.1, 0.1, 0.1, 0.7, 0, 0, 0, 0, 0, 0],

[0.1, 0.9, 0, 0, 0, 0, 0, 0, 0, 0],

[1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],

]

 _,pred = output.max(1)  的结果是 [3, 1, 0],其意思就是对每一张图片的预测结果求最大索引,给定维度1就是对这个batch_size里面的每一个[1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0]求最大索引


2
2
荼灬
回复
我的可乐说
_ 作为变量名,表示无关紧要的变量。一般不会用到的变量可以用这个来表示
2021-01-24
共2条回复

PyTorch入门到进阶 实战计算机视觉与自然语言处理项目

理论基础+技术讲解+实战开发,快速掌握PyTorch框架

1190 学习 · 293 问题

查看课程