学习代码时遇到了一个问题

来源:3-2 Tensor的基本定义

Sean_007

2021-01-07

def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

def model(xb):
    return log_softmax(xb @ weights + bias)

batch_size = 64
xb = x_train[0:batch_size]
preds = model(xb)
print(f"preds[0]: {preds[0]}, preds.shape: {preds.shape}")

yb = y_train[0:batch_size]
print(yb.shape[0])
print(preds.shape)
print(preds[range(64), yb].shape)
# output:
preds[0]: tensor([-2.3929, -2.5026, -2.1153, -2.3121, -2.7292, -2.4812, -1.9851, -1.7757,
        -2.6971, -2.4762], grad_fn=<SelectBackward>), preds.shape: torch.Size([64, 10])
64
torch.Size([64, 10])
torch.Size([64])

请问最后一行里面的preds[range(64), yb]是什么操作啊?这里的preds貌似是一个tensor,可以这样用吗?

这是在复现pytorch官网教程时遇到的一个问题,教程里用这种方式实现了一个loss function。
https://pytorch.org/tutorials/beginner/nn_tutorial.html

def nll(input, target):
    return -input[range(target.shape[0]), target].mean()

loss_func = nll

感谢!

写回答

1回答

Sean_007

提问者

2021-01-08

不好意思,一个简单的矩阵索引问题,当时没看出来,已经解决了。

0
0

PyTorch入门到进阶 实战计算机视觉与自然语言处理项目

理论基础+技术讲解+实战开发,快速掌握PyTorch框架

1190 学习 · 298 问题

查看课程