热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

06使用pytorch实现手写数字识别

目录1.思路和流程分析2.准备训练集和测试集2.1torchvision.transforms的图形数据处理方法2.1.1torchvison.transforms.ToT

目录

1.思路和流程分析

2.准备训练集和测试集

2.1 torchvision.transforms的图形数据处理方法

2.1.1 torchvison.transforms.ToTensor

2.1.2 torchvision.transforms.Normalize(mean,std)

2.1.3 torchvision.transforms.Compose(transforms)

2.2 准备MNIST数据集的Dataset和DataLoader

3.构建模型

3.1 激活函数的使用

3.2 模型中数据的形状(【添加形状变化图形】)

3.3 模型的损失函数

4.模型的训练

5.模型的保存和加载

5.1 模型的保存

5.2 模型的加载

6.模型的评估

7.总的代码




1.思路和流程分析


2.准备训练集和测试集


2.1 torchvision.transforms的图形数据处理方法


2.1.1 torchvison.transforms.ToTensor

from torchvision.datasets import MNIST
mnist=MNIST(root=r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\datas',train=True,download=True,transform=None)#len(mnist)==60000
print(mnist[0])#(, 5)
img=mnist[0][0]
img.show()#打开图片

from torchvision import transforms
import numpy as np
data=np.random.randint(0,255,size=12)
img=data.reshape(2,2,3)
print(img.shape)
img_tensor=transforms.ToTensor()(img)#转换成tensor
print(img_tensor)
print(img_tensor.size())

输出如下:

(2, 2, 3)
tensor([[[235, 30],[236, 92]],[[ 1, 113],[ 53, 5]],[[ 21, 190],[ 46, 11]]], dtype=torch.int32)
torch.Size([3, 2, 2])


2.1.2 torchvision.transforms.Normalize(mean,std)

from torchvision import transforms
import numpy as np
import torchvision
data=np.random.randint(0,255,size=12)
img=data.reshape(2,2,3)
img=transforms.ToTensor()(img)#转换成tensor
print(img)
print('*'*100)
norm_img=transforms.Normalize((10,10,10),(1,1,1))(img)#进行规范化处理
print(norm_img)

2.1.3 torchvision.transforms.Compose(transforms)

transforms.Compose([torchvision.transforms.ToTensor(),#先转换为Tensortorchvision.transforms.Normalize(mean,std)#再进行正则化
])

2.2 准备MNIST数据集的Dataset和DataLoader

准备训练集:

from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoaderBATCH_SIZE=128
#1.准备数据
def get_dataloader(train=True):transform_fn = Compose([ToTensor(),Normalize(mean=(0.1307,), std=(0.3081,)) # mean std的形状和通道数相同])dataset = MNIST(root=r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\datas', train=True, transform=transform_fn)data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)return data_loader

3.构建模型


3.1 激活函数的使用

import torch
import torch.nn.functional as F
b=torch.Tensor([-2,-1,0,1,2])
print(F.relu(b)) #tensor([0., 0., 0., 1., 2.])

3.2 模型中数据的形状(【添加形状变化图形】)

import torch.nn as nn
import torch.nn.functional as F#2.构建模型
class MnistNodel(nn.Module):def __init__(self):super(MnistNodel,self).__init__()self.fc1=nn.Linear(1*28*28,28)#第一个全连接self.fc2=nn.Linear(28,10)#第二个全连接 最终有10个类别def forward(self,input):""":param input:[batch_size,1,28,28]:return:输出层"""#1.修改形状x=input.view([input.size(0),1*28*28])#或者input.view([-1,1*28*28])#2,进行全连接的操作x=self.fc1(x)#3.进行激活函数的处理x=F.relu(x)#形状无变化#4.输出层out=self.fc2(x)return out

3.3 模型的损失函数

#方法一
criterion=nn.CrossEntropyLoss()#交叉熵损失
loss=criterion(input,target)#方法二
output=F.log_softmax(x,dim=-1)#1.对输出值计算softmax和取对数
loss=F.nll_loss(output,target)#2.使用torch中带权损失nll_loss


4.模型的训练

from torch.optim import Adammodel=MnistNodel()#实例化模型
optimizer=Adam(model.parameters(),lr=0.001)
def train(epoch):'''实现训练的过程'''data_loader=get_dataloader()for idx,(input,target) in enumerate(data_loader):optimizer.zero_grad()output=model(input)#调用模型,得到预测值loss=F.nll_loss(output,target)#得到损失loss.backward()#反向传播optimizer.step()#梯度的更新if idx%100==0:print(epoch,idx,loss.item(),sep='\t')

5.模型的保存和加载


5.1 模型的保存

torch.save(model.state_dict(),'path')#保存模型参数
torch.save(optimizer.state_dict(),'path')#保存优化器参数

5.2 模型的加载

model.load_state_dict(torch.load('path'))
optimizer.load_state_dict(torch.load('path'))

6.模型的评估

import numpy as npdef test():loss_list=[]acc_list=[]test_dataloader=get_dataloader(train=False)for idx,(input,target) in enumerate(test_dataloader):with torch.no_grad():output=model(input)cur_loss=F.nll_loss(output,target)loss_list.append(cur_loss)#计算准确率#output [batch_size,10] target:[batch_size]pred=output.max(dim=-1)[-1]cur_acc=pred.eq(target).float().mean()acc_list.append(cur_acc)print('平均准确率:',np.mean(acc_list),'\t平均损失:',np.mean(loss_list))#结果如下:
# 平均准确率: 0.9503709 平均损失: 0.17310049

7.总的代码

import torch,os
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamBATCH_SIZE=128
#1.准备数据
def get_dataloader(train=True):transform_fn = Compose([ToTensor(),Normalize(mean=(0.1307,), std=(0.3081,)) # mean std的形状和通道数相同])dataset = MNIST(root=r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\datas', train=True, transform=transform_fn)data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)return data_loader#2.构建模型
class MnistNodel(nn.Module):def __init__(self):super(MnistNodel,self).__init__()self.fc1=nn.Linear(1*28*28,28)#第一个全连接self.fc2=nn.Linear(28,10)#第二个全连接 最终有10个类别def forward(self,input):""":param input:[batch_size,1,28,28]:return:输出层"""#1.修改形状x=input.view([input.size(0),1*28*28])#或者input.view([-1,1*28*28])#2,进行全连接的操作x=self.fc1(x)#3.进行激活函数的处理x=F.relu(x)#形状无变化#4.输出层out=self.fc2(x)return F.log_softmax(out)model=MnistNodel()#实例化模型
if os.path.exists(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_model.pkl'):model.load_state_dict(torch.load(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_model.pkl'))
optimizer=Adam(model.parameters(),lr=0.001)
if os.path.exists(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_optimizer.pkl'):optimizer.load_state_dict(torch.load(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_optimizer.pkl'))def train(epoch):'''实现训练的过程'''data_loader=get_dataloader()for idx,(input,target) in enumerate(data_loader):optimizer.zero_grad()output=model(input)#调用模型,得到预测值loss=F.nll_loss(output,target)#得到损失loss.backward()#反向传播optimizer.step()#梯度的更新if idx%100==0:print(epoch,idx,loss.item(),sep='\t')#模型的保存if idx%100==0:#每隔100个保存一下torch.save(model.state_dict(),r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_model.pkl')torch.save(optimizer.state_dict(),r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_optimizer.pkl')def test():loss_list=[]acc_list=[]test_dataloader=get_dataloader(train=False)for idx,(input,target) in enumerate(test_dataloader):with torch.no_grad():output=model(input)cur_loss=F.nll_loss(output,target)loss_list.append(cur_loss)#计算准确率#output [batch_size,10] target:[batch_size]pred=output.max(dim=-1)[-1]cur_acc=pred.eq(target).float().mean()acc_list.append(cur_acc)print('平均准确率:',np.mean(acc_list),'\t平均损失:',np.mean(loss_list))if __name__ == '__main__':# for i in range(3):#训练三轮# train(i)test()


推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • Python使用Pillow包生成验证码图片的方法
    本文介绍了使用Python中的Pillow包生成验证码图片的方法。通过随机生成数字和符号,并添加干扰象素,生成一幅验证码图片。需要配置好Python环境,并安装Pillow库。代码实现包括导入Pillow包和随机模块,定义随机生成字母、数字和字体颜色的函数。 ... [详细]
  • [翻译]PyCairo指南裁剪和masking
    裁剪和masking在PyCairo指南的这个部分,我么将讨论裁剪和masking操作。裁剪裁剪就是将图形的绘制限定在一定的区域内。这样做有一些效率的因素࿰ ... [详细]
  • 一、死锁现象与递归锁进程也是有死锁的所谓死锁:是指两个或两个以上的进程或线程在执行过程中,因争夺资源而造成的一种互相等待的现象,若无外力作 ... [详细]
  • 如何实现织梦DedeCms全站伪静态
    本文介绍了如何通过修改织梦DedeCms源代码来实现全站伪静态,以提高管理和SEO效果。全站伪静态可以避免重复URL的问题,同时通过使用mod_rewrite伪静态模块和.htaccess正则表达式,可以更好地适应搜索引擎的需求。文章还提到了一些相关的技术和工具,如Ubuntu、qt编程、tomcat端口、爬虫、php request根目录等。 ... [详细]
  • 使用正则表达式爬取36Kr网站首页新闻的操作步骤和代码示例
    本文介绍了使用正则表达式来爬取36Kr网站首页所有新闻的操作步骤和代码示例。通过访问网站、查找关键词、编写代码等步骤,可以获取到网站首页的新闻数据。代码示例使用Python编写,并使用正则表达式来提取所需的数据。详细的操作步骤和代码示例可以参考本文内容。 ... [详细]
  • 本文介绍了腾讯最近开源的BERT推理模型TurboTransformers,该模型在推理速度上比PyTorch快1~4倍。TurboTransformers采用了分层设计的思想,通过简化问题和加速开发,实现了快速推理能力。同时,文章还探讨了PyTorch在中间层延迟和深度神经网络中存在的问题,并提出了合并计算的解决方案。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
  • loader资源模块加载器webpack资源模块加载webpack内部(内部loader)默认只会处理javascript文件,也就是说它会把打包过程中所有遇到的 ... [详细]
  • 读手语图像识别论文笔记2
    文章目录一、前言二、笔记1.名词解释2.流程分析上一篇快速门:读手语图像识别论文笔记1(手语识别背景和方法)一、前言一句:“做完了&#x ... [详细]
author-avatar
herozhx
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有