
PyTorch CNN: Recognizing the CIFAR10 Dataset


An attempt at recognizing CIFAR10 with a deeper network architecture.

import torch
import torchvision
import torchvision.transforms as transforms

BATCH_SIZE = 64
EPOCHES = 50
NUM_WORKERS = 4
LEARNING_RATE = 0.005

# Data transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the training and test data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

# Class labels
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified
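The `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform maps each channel from [0, 1] to [-1, 1] via (x - mean) / std; the `img / 2 + 0.5` used later in `imshow` is its inverse. A minimal sketch of the arithmetic (plain Python, illustrative helper names only):

```python
def normalize(x, mean=0.5, std=0.5):
    # what transforms.Normalize does per channel
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # the inverse, i.e. the img / 2 + 0.5 step in imshow below
    return y * std + mean

print(normalize(0.0))                # -1.0: black maps to the lower bound
print(normalize(1.0))                # 1.0: white maps to the upper bound
print(denormalize(normalize(0.3)))   # round-trips back to 0.3
```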

Next we define the network.

import torch.nn as nn
import torch.nn.functional as F

# Reference: https://www.jianshu.com/p/016a23bc6554
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.MaxPool = nn.MaxPool2d(2, 2)
        self.AvgPool = nn.AvgPool2d(4, 4)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))                # (3,32,32)  -> (16,32,32)
        x = self.MaxPool(F.relu(self.conv2(x)))  # (16,32,32) -> (32,16,16)
        x = F.relu(self.conv3(x))                # (32,16,16) -> (64,16,16)
        x = self.MaxPool(F.relu(self.conv4(x)))  # (64,16,16) -> (128,8,8)
        x = self.MaxPool(F.relu(self.conv5(x)))  # (128,8,8)  -> (256,4,4)
        x = self.AvgPool(x)                      # -> (256,1,1)
        x = x.view(-1, 256)                      # flatten to (256,)
        # note: the three hidden fc layers are stacked with no activation between them
        x = self.fc3(self.fc2(self.fc1(x)))      # -> (32,)
        x = self.fc4(x)                          # -> (10,)
        return x

net = Net()
if torch.cuda.is_available():
    net = net.cuda()
print(net)

Net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (MaxPool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (AvgPool): AvgPool2d(kernel_size=4, stride=4, padding=0)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=10, bias=True)
)
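The shape comments in `forward()` can be checked against the standard output-size formula for convolution and pooling: out = (in + 2*padding - kernel) // stride + 1. A quick sketch tracing the spatial size through the layers (pure Python, no torch needed; `conv_out` is a hypothetical helper, not part of the model):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # output spatial size of a conv or pooling layer
    return (size + 2 * padding - kernel) // stride + 1

s = 32                           # CIFAR10 images are 32x32
s = conv_out(s, 3, padding=1)    # conv1    -> 32
s = conv_out(s, 3, padding=1)    # conv2    -> 32
s = conv_out(s, 2, stride=2)     # MaxPool  -> 16
s = conv_out(s, 3, padding=1)    # conv3    -> 16
s = conv_out(s, 3, padding=1)    # conv4    -> 16
s = conv_out(s, 2, stride=2)     # MaxPool  -> 8
s = conv_out(s, 3, padding=1)    # conv5    -> 8
s = conv_out(s, 2, stride=2)     # MaxPool  -> 4
s = conv_out(s, 4, stride=4)     # AvgPool  -> 1
print(s)  # 1, so the flattened feature vector is 256 * 1 * 1 = 256
```

This confirms why `x.view(-1, 256)` and `nn.Linear(256, 128)` line up.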

Let's first look at one batch of data.

import matplotlib.pyplot as plt
import numpy as np

dataiter = iter(trainloader)
image, label = next(dataiter)  # dataiter.next() in older PyTorch versions
print(image.shape)
print(label.shape)

def imshow(img):
    img = img / 2 + 0.5  # undo the normalization
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

imshow(torchvision.utils.make_grid(image))

# Print the labels: 8 rows of 8
for i in range(8):
    print(" ".join("%5s" % classes[label[i*8+j]] for j in range(8)))

torch.Size([64, 3, 32, 32])
torch.Size([64])

[Image: grid of one batch of CIFAR10 training images (output_5_1.png)]

bird dog truck cat car dog horse ship
deer ship dog bird car cat plane deer
deer deer car truck plane dog deer ship
bird bird horse truck truck ship deer dog
truck frog plane car bird cat car plane
ship horse plane truck car deer horse ship
horse truck ship ship dog dog deer ship
plane frog dog bird plane bird ship ship

We use cross-entropy as the loss function and Adam as the optimizer.

import torch.optim as optim
from torch.autograd import Variable
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
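For reference, `CrossEntropyLoss` applies log-softmax to the raw logits and takes the negative log-probability of the target class. A pure-Python sketch of that computation on a single sample (made-up logits, not outputs of the model above):

```python
import math

def cross_entropy(logits, target):
    # numerically stable log-sum-exp: subtract the max logit first
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # loss = -log(softmax(logits)[target]) = log_sum_exp - logits[target]
    return log_sum_exp - logits[target]

logits = [2.0, 1.0, 0.1]            # raw scores for 3 classes (illustrative)
print(cross_entropy(logits, 0))     # ~0.417: low loss, class 0 has the top score
print(cross_entropy(logits, 2))     # ~2.317: high loss, class 2 scores poorly
```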

Define the training and test functions.

def train(path):
    losses = []
    acces = []
    test_acc = []
    print("------train start------")
    for epoch in range(1, EPOCHES+1):
        train_loss = 0
        train_acc = 0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            if torch.cuda.is_available():
                inputs = Variable(inputs).cuda()
                labels = Variable(labels).cuda()
            else:
                inputs = Variable(inputs)
                labels = Variable(labels)
            # Forward and backward pass
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Accumulate the loss
            train_loss += loss.item()
            # Accumulate training accuracy
            _, pred = outputs.max(1)
            num_correct = (pred == labels).sum().item()
            acc = num_correct / BATCH_SIZE
            train_acc += acc
        # End of epoch: record the statistics
        losses.append(train_loss / len(trainloader))
        acces.append(train_acc / len(trainloader))
        test_acc.append(test(path))
        print("epoch: {}, Train Loss: {:.6f}, Train acc: {:.6f}, Test acc: {:.6f}"
              .format(epoch, losses[epoch-1], acces[epoch-1], test_acc[epoch-1]))
    torch.save(net.state_dict(), path)
    print('Finished Training')
    return losses, acces, test_acc

def test(path):
    print("------test start------")
    test_acc = 0
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        if torch.cuda.is_available():
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
        else:
            inputs = Variable(inputs)
            labels = Variable(labels)
        # Forward pass only (gradients are not needed here)
        optimizer.zero_grad()
        outputs = net(inputs)
        # Accumulate per-batch accuracy (divided by BATCH_SIZE)
        _, pred = outputs.max(1)
        acc = (pred == labels).sum().item() / BATCH_SIZE
        test_acc += acc
    print("test acc: {:.6f}".format(test_acc / len(testloader)))
    return test_acc / len(testloader)
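One caveat: both functions divide each batch's correct count by BATCH_SIZE, which slightly understates accuracy because CIFAR10's 10,000 test images split into 156 full batches of 64 plus a final batch of only 16. A sketch of the size of that bias (synthetic per-batch counts, not real model output):

```python
BATCH_SIZE = 64
n_test = 10000
# 156 full batches of 64 plus one final batch of 16
batch_sizes = [BATCH_SIZE] * (n_test // BATCH_SIZE) + [n_test % BATCH_SIZE]

# suppose the model gets about 65% right in every batch (made-up numbers)
correct = [round(b * 0.65) for b in batch_sizes]

# per-batch average as in test(): every batch divided by BATCH_SIZE
biased = sum(c / BATCH_SIZE for c in correct) / len(batch_sizes)
# exact accuracy: total correct over total samples
exact = sum(correct) / n_test
print(biased, exact)  # biased is slightly lower because of the 16-sample batch
```

The effect is small (a fraction of a percent here), but summing `num_correct` and dividing by the dataset size once at the end avoids it entirely.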

losses = []
acces = []
test_acc = []
path = "CIFAR10_Deep.pth"
losses, acces, test_acc = train(path)

------train start------
------test start------
test acc: 0.346338
epoch: 1, Train Loss: 1.823356, Train acc: 0.271439, Test acc: 0.346338
------test start------
test acc: 0.471537
epoch: 2, Train Loss: 1.486024, Train acc: 0.434043, Test acc: 0.471537
------test start------
test acc: 0.532245
epoch: 3, Train Loss: 1.279489, Train acc: 0.527354, Test acc: 0.532245
------test start------
test acc: 0.568969
epoch: 4, Train Loss: 1.168447, Train acc: 0.574628, Test acc: 0.568969
------test start------
test acc: 0.604100
epoch: 5, Train Loss: 1.096343, Train acc: 0.603201, Test acc: 0.604100
------test start------
test acc: 0.615844
epoch: 6, Train Loss: 1.052584, Train acc: 0.624181, Test acc: 0.615844
------test start------
test acc: 0.609375
epoch: 7, Train Loss: 1.002270, Train acc: 0.639466, Test acc: 0.609375
------test start------
test acc: 0.627986
epoch: 8, Train Loss: 0.965407, Train acc: 0.653613, Test acc: 0.627986
------test start------
test acc: 0.648487
epoch: 9, Train Loss: 0.935815, Train acc: 0.664762, Test acc: 0.648487
------test start------
test acc: 0.647492
epoch: 10, Train Loss: 0.898331, Train acc: 0.679208, Test acc: 0.647492
------test start------
test acc: 0.659136
epoch: 11, Train Loss: 0.866001, Train acc: 0.691196, Test acc: 0.659136
------test start------
test acc: 0.667396
epoch: 12, Train Loss: 0.836056, Train acc: 0.704364, Test acc: 0.667396
------test start------
test acc: 0.670482
epoch: 13, Train Loss: 0.810909, Train acc: 0.712036, Test acc: 0.670482
------test start------
test acc: 0.661027
epoch: 14, Train Loss: 0.785331, Train acc: 0.721947, Test acc: 0.661027
------test start------
test acc: 0.676254
epoch: 15, Train Loss: 0.752643, Train acc: 0.733696, Test acc: 0.676254
------test start------
test acc: 0.667297
epoch: 16, Train Loss: 0.720532, Train acc: 0.743606, Test acc: 0.667297
------test start------
test acc: 0.673069
epoch: 17, Train Loss: 0.702551, Train acc: 0.750959, Test acc: 0.673069
------test start------
test acc: 0.678344
epoch: 18, Train Loss: 0.663924, Train acc: 0.763807, Test acc: 0.678344
------test start------
test acc: 0.671477
epoch: 19, Train Loss: 0.648946, Train acc: 0.770480, Test acc: 0.671477
------test start------
test acc: 0.671079
epoch: 20, Train Loss: 0.625400, Train acc: 0.777873, Test acc: 0.671079
------test start------
test acc: 0.668690
epoch: 21, Train Loss: 0.591782, Train acc: 0.789502, Test acc: 0.668690
------test start------
test acc: 0.678045
epoch: 22, Train Loss: 0.569181, Train acc: 0.796875, Test acc: 0.678045
------test start------
test acc: 0.667496
epoch: 23, Train Loss: 0.545593, Train acc: 0.804867, Test acc: 0.667496
------test start------
test acc: 0.648288
epoch: 24, Train Loss: 0.521233, Train acc: 0.814218, Test acc: 0.648288
------test start------
test acc: 0.660231
epoch: 25, Train Loss: 0.500628, Train acc: 0.821032, Test acc: 0.660231
------test start------
test acc: 0.671576
epoch: 26, Train Loss: 0.495962, Train acc: 0.823030, Test acc: 0.671576
------test start------
test acc: 0.668193
epoch: 27, Train Loss: 0.466635, Train acc: 0.833300, Test acc: 0.668193
------test start------
test acc: 0.646198
epoch: 28, Train Loss: 0.445843, Train acc: 0.839934, Test acc: 0.646198
------test start------
test acc: 0.656449
epoch: 29, Train Loss: 0.434789, Train acc: 0.845269, Test acc: 0.656449
------test start------
test acc: 0.659236
epoch: 30, Train Loss: 0.399209, Train acc: 0.859055, Test acc: 0.659236
------test start------
test acc: 0.659236
epoch: 31, Train Loss: 0.405278, Train acc: 0.855319, Test acc: 0.659236
------test start------
test acc: 0.664013
epoch: 32, Train Loss: 0.383631, Train acc: 0.864110, Test acc: 0.664013
------test start------
test acc: 0.639829
epoch: 33, Train Loss: 0.368625, Train acc: 0.869266, Test acc: 0.639829
------test start------
test acc: 0.654956
epoch: 34, Train Loss: 0.376511, Train acc: 0.865589, Test acc: 0.654956
------test start------
test acc: 0.654260
epoch: 35, Train Loss: 0.337547, Train acc: 0.878976, Test acc: 0.654260
------test start------
test acc: 0.651174
epoch: 36, Train Loss: 0.354066, Train acc: 0.873481, Test acc: 0.651174
------test start------
test acc: 0.627986
epoch: 37, Train Loss: 0.321658, Train acc: 0.884531, Test acc: 0.627986
------test start------
test acc: 0.652269
epoch: 38, Train Loss: 0.336692, Train acc: 0.881694, Test acc: 0.652269
------test start------
test acc: 0.630573
epoch: 39, Train Loss: 0.303990, Train acc: 0.892363, Test acc: 0.630573
------test start------
test acc: 0.653463
epoch: 40, Train Loss: 0.306856, Train acc: 0.891824, Test acc: 0.653463
------test start------
test acc: 0.656748
epoch: 41, Train Loss: 0.302821, Train acc: 0.893143, Test acc: 0.656748
------test start------
test acc: 0.654061
epoch: 42, Train Loss: 0.292990, Train acc: 0.896739, Test acc: 0.654061
------test start------
test acc: 0.656250
epoch: 43, Train Loss: 0.284619, Train acc: 0.899157, Test acc: 0.656250
------test start------
test acc: 0.653762
epoch: 44, Train Loss: 0.301842, Train acc: 0.894581, Test acc: 0.653762
------test start------
test acc: 0.647393
epoch: 45, Train Loss: 0.249959, Train acc: 0.913183, Test acc: 0.647393
------test start------
test acc: 0.646994
epoch: 46, Train Loss: 0.288888, Train acc: 0.899177, Test acc: 0.646994
------test start------
test acc: 0.633459
epoch: 47, Train Loss: 0.275196, Train acc: 0.904572, Test acc: 0.633459
------test start------
test acc: 0.650577
epoch: 48, Train Loss: 0.263865, Train acc: 0.908148, Test acc: 0.650577
------test start------
test acc: 0.640326
epoch: 49, Train Loss: 0.264790, Train acc: 0.907809, Test acc: 0.640326
------test start------
test acc: 0.649781
epoch: 50, Train Loss: 0.236096, Train acc: 0.917439, Test acc: 0.649781
Finished Training

x = np.arange(1, 1+EPOCHES)
plt.plot(x, losses)
plt.title("train losses")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.grid()

[Image: training loss curve (output_12_0.png)]

plt.plot(x, acces, label="train acc")
plt.plot(x, test_acc, label="test acc")
plt.title("accuracy")
plt.xlabel("epoch")
plt.ylabel("acc")
plt.grid()
plt.legend()
plt.show()

[Image: train vs. test accuracy curves]

We can see that although training accuracy keeps climbing, test accuracy plateaus around 0.65 and stops improving, a sign the model is overfitting.
Let's test the first eight samples from the test set.

dataiter = iter(testloader)
image, label = next(dataiter)  # dataiter.next() in older PyTorch versions
print(image.shape)
print(label.shape)

imshow(torchvision.utils.make_grid(image[:8]))

image = Variable(image).cuda()
output = net(image)
_, pred = output.max(1)
print("labels:     " + " ".join(classes[label[i]] for i in range(8)))
print("prediction: " + " ".join(classes[pred[i]] for i in range(8)))

torch.Size([64, 3, 32, 32])
torch.Size([64])

[Image: the first eight test images]

labels:     cat ship ship plane frog frog car frog
prediction: dog car ship ship frog frog car bird
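Since the best test accuracy appeared well before epoch 50, early stopping would be a natural improvement: keep the checkpoint with the best test accuracy and stop once it has failed to improve for a few epochs. A minimal sketch of the bookkeeping (synthetic accuracy values shaped like the run above; wiring it into train() is left as an exercise):

```python
def early_stop_epoch(test_accs, patience=5):
    """Return the 1-based epoch where training should stop: the epoch with the
    best accuracy, once `patience` epochs pass without a new best."""
    best, best_epoch = float("-inf"), 0
    for epoch, acc in enumerate(test_accs, start=1):
        if acc > best:
            best, best_epoch = acc, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch
    return best_epoch

# made-up curve shaped like the run above: rises, then plateaus and decays
accs = [0.35, 0.47, 0.53, 0.60, 0.62, 0.65, 0.66, 0.68, 0.67, 0.66,
        0.66, 0.65, 0.64, 0.65, 0.64]
print(early_stop_epoch(accs))  # 8: the epoch with the peak value 0.68
```

In practice the checkpoint would be saved with `torch.save(net.state_dict(), path)` at each new best, so the best weights survive the later overfit epochs.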


推荐阅读
  • 本文介绍了如何使用n3-charts绘制以日期为x轴的数据,并提供了相应的代码示例。通过设置x轴的类型为日期,可以实现对日期数据的正确显示和处理。同时,还介绍了如何设置y轴的类型和其他相关参数。通过本文的学习,读者可以掌握使用n3-charts绘制日期数据的方法。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • MyBatis多表查询与动态SQL使用
    本文介绍了MyBatis多表查询与动态SQL的使用方法,包括一对一查询和一对多查询。同时还介绍了动态SQL的使用,包括if标签、trim标签、where标签、set标签和foreach标签的用法。文章还提供了相关的配置信息和示例代码。 ... [详细]
  • Gitlab接入公司内部单点登录的安装和配置教程
    本文介绍了如何将公司内部的Gitlab系统接入单点登录服务,并提供了安装和配置的详细教程。通过使用oauth2协议,将原有的各子系统的独立登录统一迁移至单点登录。文章包括Gitlab的安装环境、版本号、编辑配置文件的步骤,并解决了在迁移过程中可能遇到的问题。 ... [详细]
  • 本文介绍了一个适用于PHP应用快速接入TRX和TRC20数字资产的开发包,该开发包支持使用自有Tron区块链节点的应用场景,也支持基于Tron官方公共API服务的轻量级部署场景。提供的功能包括生成地址、验证地址、查询余额、交易转账、查询最新区块和查询交易信息等。详细信息可参考tron-php的Github地址:https://github.com/Fenguoz/tron-php。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • 本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。 ... [详细]
  • 本文讨论了如何使用GStreamer来删除H264格式视频文件中的中间部分,而不需要进行重编码。作者提出了使用gst_element_seek(...)函数来实现这个目标的思路,并提到遇到了一个解决不了的BUG。文章还列举了8个解决方案,希望能够得到更好的思路。 ... [详细]
  • OpenMap教程4 – 图层概述
    本文介绍了OpenMap教程4中关于地图图层的内容,包括将ShapeLayer添加到MapBean中的方法,OpenMap支持的图层类型以及使用BufferedLayer创建图像的MapBean。此外,还介绍了Layer背景标志的作用和OMGraphicHandlerLayer的基础层类。 ... [详细]
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • 本文整理了Java中org.apache.solr.common.SolrDocument.setField()方法的一些代码示例,展示了SolrDocum ... [详细]
  • 本文整理了常用的CSS属性及用法,包括背景属性、边框属性、尺寸属性、可伸缩框属性、字体属性和文本属性等,方便开发者查阅和使用。 ... [详细]
  • 我用Tkinter制作了一个图形用户界面,有两个主按钮:“开始”和“停止”。请您就如何使用“停止”按钮终止“开始”按钮为以下代码调用的已运行功能提供建议 ... [详细]
  • 引号快捷键_首选项和设置——自定义快捷键
    3.3自定义快捷键(CustomizingHotkeys)ChemDraw快捷键由一个XML文件定义,我们可以根据自己的需要, ... [详细]
author-avatar
挖掘机销售mv
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有