热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

龙良曲pytorch学习笔记_10

LeNet5网络和CIFAR10数据集main函数--dataloader--train--test1importtorch2fromtorch.utils.dataimportD

LeNet5网络和CIFAR10数据集

main函数--dataloader--train--test


1 import torch
2 from torch.utils.data import DataLoader
3 from torchvision import datasets
4 from torchvision import transforms
5 from torch import nn,optim
6 from lenet5 import LeNet5
7
8 def main():
9 batch_size = 32
10 cifar_train = datasets.CIFAR10(cifar,train = True,transform = transforms.Compose([
11 transforms.Resize((32,32)),
12 transforms.ToTensor()
13 ]),download = True)
14
15 # 可以同时加载多张图片
16 cifar_train = DataLoader(cifar_train,batch_size = batch_size,shuffle = True)
17
18 cifar_test = datasets.CIFAR10(cifar,train = False,transform = transforms.Compose([
19 transforms.Resize((32,32)),
20 transforms.ToTensor()
21 ]),download = True)
22
23 # 可以同时加载多张图片
24 cifar_test = DataLoader(cifar_test,batch_size = batch_size,shuffle = True)
25
26 # 数据加载成功后可以检验shape
27 x,label = iter(cifar_train).next()
28 print(x:,x.shape,label:,label.shape)
29
30 device = torch.device(cuda)
31 model = LeNet5().to(device)
32 criteon = nn.CrossEntropyLoss().to(device)
33 optimizer = optim.Adam(model.parameters(),lr=1e-3)
34
35 print(model)
36
37 for epoch in range(1000):
38
39 model.train()
40 for batchidx,(x,label) in enumerate(cifar_train):
41 # x: [b,3,32,32], label: [b]
42 x,label = x.to(device),label.to(device)
43
44 logits = model(x)
45 # logits:[b,10]
46 # label:[b]
47 loss = criteon(logits,label)
48
49 # backprop
50 optimizer.zero_grad()
51 loss.backwark()
52 optimizer.step()
53
54 #
55 print(epoch,loss.item())
56
57 model.eval()
58 # 不需要做梯度相关计算
59 with torch.nn_grad():
60 # test
61 total_correct = 0
62 total_num = 0
63 for x,label in cifar_test:
64 x,label = x.to(device),label.to(device)
65 # logits:[b,10]
66 logits = model(x)
67 pred = logits.argmax(dim=1)
68 # 获取一个batch的在累加
69 total_correct = += torch.eq(pred,label).float().sum().item()
70 # x.size(0)就是batch_size
71 total_num += x.size(0)
72
73 acc = total_correct / total_num
74 print(epoch,acc)
75
76 if __name__ == __main__
77 main()

LeNet网络--tmp测试


1 import torch
2 from torch import nn
3 from torch.nn import functional as F
4
5 class LeNet5(nn.Module):
6 """
7 for cifar10 dataset.
8 """
9 def __init__(self):
10 super(LeNet5,self).__init__()
11
12 self.conv_unit = nn.Sequential(
13 # x:[b,3,32,32] --> [b,6,]
14 # input_channel,output_channel,kernel_size,stride,padding
15 nn.Conv2d(3,6,kernel_size = 5,stride = 1,padding = 0),
16 nn.AvgPool2d(kernel_size = 2,stride = 2,padding = 0),
17 #
18 nn.Conv2d(6,16,kernel_size = 5,stride = 1,padding = 0),
19 nn.AvgPool2d(kernel_size = 2,stride = 2,padding = 0),
20 )
21 # Flatten
22 # fc_unit
23 self.fc_unit = nn.Sequential(
24 # 由下面的测试得出来的
25 nn.Linear(16*5*5,120),
26 # 全连接层会出现梯度离散现象,加一个relu
27 nn.ReLU(),
28 nn.Linear(120,84),
29 nn.ReLU(),
30 nn.Linear(84,10),
31 )
32 ‘‘‘
33 tmp = torch.randn(2,3,32,32)
34 out = self.conv_unit(tmp)
35 # 测试一下输出的维度,用于全连接层
36 # [2,16,5,5]
37 print(‘conv_out:‘,out.shape)
38 ‘‘‘
39
40 # use Cross Entropy Loss
41 # 放到类外,不用引入y参数
42 # self.criteon = nn.CrossEntropyLoss()
43
44 # 从左往右走的,backward会自动根据这个走
45 def forward(self,x):
46 # 取得x的shape,然后0号为batch_size
47 batch_size = x.size(0)
48 # [b,3,32,32] --> [b,16,5,5]
49 x = self.conv_unit(x)
50 # [b,16,5,5] --> [b,16*5*5]
51 x = x.view(batch_size,16*5*5)
52 # [b,16*5*5] --> [b,10]
53 logits = self.fc_unit(x)
54 return logits
55 # [b,10] crossEntropy会包含,不用写
56 # pred = F.softmax(logits,dim = 1)
57 # loss = self.criteon(logits,y)
58
59
60 def main():
61
62 net = LeNet5()
63 tmp = torch.randn(2,3,32,32)
64 out = net(tmp)
65 print(lenet_out:,out.shape)
66
67
68 if __name__ == __main__
69 main()

 


推荐阅读
  • npmimportuse这里我记录一下,视频地址和封面地址均引用的是服务器端得,本地的视频和图片 ... [详细]
  • Xib九宫格应用管理使用xib封装一个自定义view的步骤1新建一个继承UIView的自定义view,假设类名叫做(AppView)2新建一个AppView.xib文件来描述 ... [详细]
  • iOS之富文本
    之前做项目时遇到一个问题:使用UITextView显示一段电影的简介,由于字数比较多,所以字体设置的很小,行间距和段间距也很小,一大段文字挤在一起看起来很别扭,想要把行间距调大,结 ... [详细]
  • 目录结构如下:Nginx基础知识NginxHTTP服务器的特色及优点Nginx的主要企业功能Nginx作为web服务器的主要应用场景包括:Nginx的安装安装环境 ... [详细]
  • 《每个设计师都应该掌握的50个css代码段》11~20段
    2019独角兽企业重金招聘Python工程师标准11.胶卷边框img.polaroid{background:#000;*Changethistoabackgroundima ... [详细]
  • ARToolKitunity
    ARToolKit为开源的AR库,相对于高通和easyAr有几点特点:1)开源2)识别项目可以动态添加(详细在后)3)识别文件可以本地生成4)目前只能识别图片(目前为.jpg格式) ... [详细]
  • linux文件系统和挂载
    创建ISO文件cpdevcdrom目的地.isomkfs命令生成对应·的文件系统但是使用mkfs没有办法修该生成的系统文件的某些特性,例如标记LABEL,如果强行修改会导致文件里面 ... [详细]
  • salesforce lightning零基础学习(七) 列表展示数据时两种自定义编辑页面
    上一篇Lightning内容描述的是LDS,通过LDS可以很方便的实例化一个对象的数据信息。当我们通过列表展示数据需要编辑时,我们常使用两种方式去处理编辑页面:PopUpWindo ... [详细]
  • 开发网站你需要知晓的部分专用术语
      越来越多的企业和个人都在拥有属于自己的网站门户,首当其冲的就是你得知晓几个网站方面的专业术语,先是中就有好多的客户不明白这些,造成误会是正常的,那不如我们对它有个大致的了解,这样就不容易感觉 ... [详细]
  • 2019.4.14第1001题:SumProblemProblemDescriptionHey,welcometoHDOJ(HangzhouDianziUniversityOnli ... [详细]
  • 这一篇主要总结一下jQuery这个js在引入的时候做的一些初始化工作第一句window.undefinedwindow.undefined;是为了兼容低版本的IE而写的因为在低版本 ... [详细]
  • spotify engineering culture part 1
    原文,因为原视频说的太快太长,又没有字幕,于是借助youtube,把原文听&打出来了。中文版日后有时间再翻译。oneofthebigsucceessfactorshereatSpo ... [详细]
  • MyBatis模糊查询和多条件查询一、ISmbmsUserDao层根据姓名模糊查询publicListgetUser();多条件查询publicList ... [详细]
  • vscode里的html标签导航的一系列问题
    哈喽,我今天带来的经验是,vscode在18年10月更新后的1.29以后,编辑html文档时,会发现最上面有个类似于HTML标签导航的玩意儿,可能部分同学和我一样不习惯用它们,现在 ... [详细]
  • NickLa制作了另伟大的教程。NickLa向我们展示了如何装饰,而无需编辑源图像的图像和照片画廊。诀窍是很简单。所有你需要的是一个额外的标签和应用背景图像创建的叠加 ... [详细]
author-avatar
书友73892718
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有