作者:xinyaolin_857 | 来源:互联网 | 2023-08-21 21:33
手写数字识别是深度学习界的“Hello World”。网上相关代码很多，找一份自己能读懂的代码仔细研读，对理解整个网络的学习过程会很有帮助。
以下代码会自动下载 MNIST 数据集；如果下载过慢，可以先自行下载 MNIST 数据集，再放到相应目录（代码中为 ./data/）下。
MNIST 数据集（提取码：1234）
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Number of samples per mini-batch, shared by the training and test loaders.
batch_size = 64

# MNIST training split, stored under ./data/. ToTensor() converts each PIL
# image to a float tensor. download=True fetches the files if they are missing.
train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

# MNIST test split from the same directory. download=True here as well so this
# split does not depend on the training split having been constructed first.
test_dataset = datasets.MNIST(root='./data/',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)
# The training loader reshuffles every epoch. The test loader keeps a fixed
# order so evaluation output is reproducible.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


class Net(nn.Module):
    """Small three-stage CNN that classifies MNIST digits into 10 classes.

    Each stage is conv -> 2x2 max-pool -> ReLU. For a 1x28x28 input the
    spatial size goes 28 -(conv5)-> 24 -(pool)-> 12 -(conv5)-> 8 -(pool)-> 4
    -(conv3)-> 2 -(pool)-> 1. The flattened feature vector therefore has
    exactly 40 values (conv3's output channels), which is what ``self.fc``
    expects.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.conv3 = nn.Conv2d(20, 40, 3)
        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(40, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes, shape (batch, 10).

        Assumes ``x`` is (batch, 1, 28, 28), as produced by ToTensor() on
        MNIST. Other sizes would not flatten to 40 features.
        """
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = F.relu(self.mp(self.conv2(x)))
        x = F.relu(self.mp(self.conv3(x)))
        x = x.view(in_size, -1)  # flatten to (batch, 40)
        x = self.fc(x)
        # Normalise over the class axis. Passing dim explicitly avoids the
        # deprecated implicit-dim behaviour (and its warning).
        return F.log_softmax(x, dim=1)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch):
    """Run one epoch of SGD over ``train_loader``.

    Args:
        epoch: epoch number, used only in the progress log.

    Prints the current batch loss every 200 batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        # Net.forward returns log-probabilities, so NLL loss gives the
        # cross-entropy.
        loss = F.nll_loss(output, target)
        if batch_idx % 200 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test():
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy."""
    model.eval()
    test_loss = 0
    correct = 0
    # no_grad replaces the removed Variable(volatile=True). Without it,
    # evaluation would build autograd graphs for every batch.
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Sum per-sample losses here. The total is divided by the dataset
            # size below to get the true per-sample average.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # predicted class index
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# Runs epochs 1 through 9, evaluating on the test set after each one.
for epoch in range(1, 10):
    train(epoch)
    test()