在运行 train 时提示错误,应该是 vggnet 中的 fc 层出错了

来源:6-6 PyTorch搭建 VGGNet 实现Cifar10图像分类

慕无忌8219680

2022-07-26

提示的错误:
File "D:\python\py\train.py", line 27, in <module>
outputs = net(inputs)
File "C:\soft\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "D:\python\py\vggnet.py", line 72, in forward
out = self.fc(out)

代码基本上是按老师给的写的,我用的是 Win11、Python 3.9
import torch
import torch.nn as nn
import torch.nn.functional as F

class VGGbase(nn.Module):
    """Small VGG-style CNN producing 10-class log-probabilities.

    The network is four convolutional stages, each followed by a 2x2 max
    pool, and a single linear classifier. Spatial sizes through the pools:

    * 28x28 input: 28 -> 14 -> 7 -> 4 -> 2
    * 32x32 input: 32 -> 16 -> 8 -> 5 -> 2

    Both end at a 512 x 2 x 2 feature map, which is what ``fc`` expects.
    Other input sizes may produce a different final map and will then fail
    inside ``fc`` with a shape-mismatch error.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return a 3x3 same-padding Conv2d -> BatchNorm2d -> ReLU stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Stage 1: 3 -> 64 channels, then halve the spatial size.
        self.conv1 = self._conv_bn_relu(3, 64)
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 64 -> 128 channels, then halve the spatial size.
        self.conv2_1 = self._conv_bn_relu(64, 128)
        self.conv2_2 = self._conv_bn_relu(128, 128)
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 128 -> 256 channels. padding=1 on the pool rounds an odd
        # size up instead of down (7 -> 4; an even 8 becomes 5).
        self.conv3_1 = self._conv_bn_relu(128, 256)
        self.conv3_2 = self._conv_bn_relu(256, 256)
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)

        # Stage 4: 256 -> 512 channels. This pool must NOT be padded: with
        # padding=1 the output is 3x3 (512 * 9 features) and no longer
        # matches the 512 * 4 inputs of ``fc``. Unpadded, 4 -> 2 and 5 -> 2.
        self.conv4_1 = self._conv_bn_relu(256, 512)
        self.conv4_2 = self._conv_bn_relu(512, 512)
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # batchsize * 512 * 2 * 2 --> batchsize * (512 * 4)
        self.fc = nn.Linear(512 * 2 * 2, 10)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: tensor of shape (batchsize, 3, H, W); H = W = 28 or 32 gives
                the 2x2 final feature map the classifier requires.

        Returns:
            Tensor of shape (batchsize, 10) holding log-softmax scores.
        """
        batchsize = x.size(0)

        out = self.conv1(x)
        out = self.max_pooling1(out)

        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)

        out = self.conv3_1(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)

        out = self.conv4_1(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)

        # batchsize * c * h * w --> batchsize * n
        out = out.view(batchsize, -1)
        out = self.fc(out)

        return F.log_softmax(out, dim=1)

def VGGNet():
    """Build and return a new ``VGGbase`` model instance."""
    model = VGGbase()
    return model

写回答

1回答

会写代码的好厨师

2022-08-09

Fc层的参数设置和输入这一层的尺寸有关系,检查下是不是输入图片尺寸和课程里面介绍的有出入。

0
0

PyTorch入门到进阶 实战计算机视觉与自然语言处理项目

理论基础+技术讲解+实战开发,快速掌握PyTorch框架

1190 学习 · 298 问题

查看课程