为什么下载mnist数据集总是失败?
来源:4-5 利用神经网络解决分类和回归问题(3)

KingCoder
2021-03-19
import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
# data
# 加载手写数字识别所需数据集,并且定义存储路径,将数据集转换为Tensor,并对数据集进行下载
train_data = dataset.MNIST(root="mnist",
train=True,
transform=transforms.ToTensor(),
download=True)
test_data = dataset.MNIST(root="mnist",
train=False,
transform=transforms.ToTensor(),
download=False)
# batch size(训练数据较多)
# 对数据进行分批提取
# 定义数据集、batch大小、打乱数据排列
train_loader = data_utils.DataLoader(dataset=train_data,
batch_size=64,
shuffle=True)
test_loader = data_utils.DataLoader(dataset=test_data,
batch_size=64,
shuffle=True)
# net(一个卷积层和一个FC层)
class CNN(torch.nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.fc = torch.nn.Linear(14 * 14 * 32, 10)
def forward(self, x):
out = self.conv(x)
out = out.view(out.size()[0], -1)
out = self.fc(out)
return out
cnn = CNN()
# loss(交叉熵损失函数)
loss_func = torch.nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.01)
# training
for epoch in range(10):
for i, (images, labels) in enumerate(train_loader):
# images = images.cuda()
# labels = labels.cuda()
outputs = cnn(images)
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("epoch is {}, ite is "
"{}/{}, loss is {}".format(epoch + 1, i,
len(train_data) // 64,
loss.item()))
# eval/test
loss_test = 0
accuracy = 0
for i, (images, labels) in enumerate(test_loader):
# images = images.cuda()
# labels = labels.cuda()
outputs = cnn(images)
# [batchsize]
# outputs = batchsize * cls_num
loss_test += loss_func(outputs, labels)
_, pred = outputs.max(1)
accuracy += (pred == labels).sum().item()
accuracy = accuracy / len(test_data)
loss_test = loss_test / (len(test_data) // 64)
print("epoch is {}, accuracy is {}, "
"loss test is {}".format(epoch + 1,
accuracy,
loss_test.item()))
torch.save(cnn, "model/mnist_model.pkl")
写回答
1回答
-
会写代码的好厨师
2021-04-10
看样子是官网挂了。
用这个吧
https://pan.baidu.com/s/1pLcpsk7
00
相似问题
为什么下载mnist数据集总是失败?
回答 1
为什么下载mnist数据集总是失败?
回答 1
为什么下载mnist数据集总是失败?
回答 1
mnist数据集下载失败?什么原因?
回答 1
bug报错不一样
回答 2