热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

一个调试通过的用pytorch实现mnist数据集数字识别

手写数字识别是深度学习界的“HELLO WORLD”。网上代码很多,找一份自己读懂,对整个学习网络理解会有帮助。这是自动下载mnist数据集的。

手写数字识别是深度学习界的“HELLO WORLD”。网上代码很多,找一份自己读懂,对整个学习网络理解会有帮助。
这是自动下载mnist数据集的,如果下载过慢,可以自己把mnist数据集下载好放到相应目录下面
在这里插入图片描述MNIST数据集 提取码:1234

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable  # NOTE: Variable is deprecated since PyTorch 0.4; kept for the legacy eval code below

# Training settings
batch_size = 64  # samples per mini-batch

# MNIST dataset: bundled with torchvision.datasets, downloaded on first use.
# BUG FIX: in the scraped original these assignments were fused onto comment
# lines (everything after '#' was dead text), so the datasets were never
# created; the statements are restored to their own lines here.
train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./data/',
                              train=False,
                              transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 输入1通道,输出10通道,kernel 5*5self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, 5)self.conv3 = nn.Conv2d(20, 40, 3)self.mp = nn.MaxPool2d(2)# fully connectself.fc = nn.Linear(40, 10)#(in_features, out_features)def forward(self, x):# in_size = 64in_size = x.size(0) # one batch 此时的x是包含batchsize维度为4的tensor,即(batchsize,channels,x,y),x.size(0)指batchsize的值 把batchsize的值作为网络的in_size# x: 64*1*28*28x = F.relu(self.mp(self.conv1(x)))# x: 64*10*12*12 feature map =[(28-4)/2]^2=12*12x = F.relu(self.mp(self.conv2(x)))# x: 64*20*4*4x = F.relu(self.mp(self.conv3(x)))x = x.view(in_size, -1) # flatten the tensor 相当于resharp# print(x.size())# x: 64*320x = self.fc(x)# x:64*10# print(x.size())return F.log_softmax(x) #64*10
model = Net()

# SGD with momentum over all learnable parameters
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch):
    """Run one training epoch over train_loader, logging every 200 batches.

    Args:
        epoch: 1-based epoch index, used only for the progress message.
    """
    model.train()  # ensure training mode (no-op for this net, but good practice)
    for batch_idx, (data, target) in enumerate(train_loader):
        # data: (batch, 1, 28, 28); target: (batch,)
        optimizer.zero_grad()         # clear gradients from the previous step
        output = model(data)          # (batch, 10) log-probabilities
        loss = F.nll_loss(output, target)
        loss.backward()               # backpropagate
        optimizer.step()              # parameter update
        if batch_idx % 200 == 0:
            # BUG FIX: loss.data.item() -> loss.item() (the .data indirection
            # is deprecated and bypasses autograd bookkeeping).
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    """Evaluate on test_loader; print average NLL loss and accuracy."""
    model.eval()
    test_loss = 0
    correct = 0
    # BUG FIX: Variable(data, volatile=True) is deprecated since PyTorch 0.4;
    # torch.no_grad() is the supported way to disable gradient tracking.
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Sum up batch loss.
            # BUG FIX: size_average=False is deprecated -> reduction='sum'.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Index of the max log-probability == predicted class.
            # (Removed the per-batch debug print(pred), which flooded stdout.)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, 10):
    train(epoch)
    test()

在这里插入图片描述


推荐阅读
author-avatar
xinyaolin_857
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有