热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

AlexNet基于MNIST数据集的代码实现

AlexNet基于MNIST数据集的代码实现鉴于原论文中使用的数据集过于庞大,分类过多,目前手头的设备运行是在过于缓慢,折中考虑尝试使用
AlexNet基于MNIST数据集的代码实现

鉴于原论文中使用的数据集过于庞大,分类过多,目前手头的设备运行是在过于缓慢,折中考虑尝试使用MNIST的数据集实现AlexNet

import torch, torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
%matplotlib inline
import copy

# 超参数设置
EPOCH = 10
BATCH_SIZE = 64
LR = 0.01

transform = transforms.ToTensor()

数据集

通过torchvision下载数据集

trainset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)testset = torchvision.datasets.MNIST(root='../data', train=True, transform=transform)

C:\Users\Administrator\AppData\Roaming\Python\Python36\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:180.)return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

绘图查看

plt.imshow(trainset[4][0][0], cmap='gray')


请添加图片描述

查看数据格式

trainset[0][0].shape

torch.Size([1, 28, 28])

为了通用,设置一个device,如果有显卡并配置好了cuda环境,那么就选择为cuda,否则为cpu

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

AlexNet

同样地,仿照AlexNet,设置了五个卷积层和三个全连接层构建一个深度卷积神经网络,网络的定义是重写nn.Module实现的,卷积层和全连接层之间将数据通过view拉平[1]

class AlexNet(nn.Module):def __init__(self,width_mult=1):super(AlexNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, padding=1), # 32*28*28nn.MaxPool2d(kernel_size=2, stride=2), # 32*14*14nn.ReLU(inplace=True),)self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1), # 64*14*14nn.MaxPool2d(kernel_size=2, stride=2), # 64*7*7nn.ReLU(inplace=True),)self.layer3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), # 128*7*7)self.layer4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, padding=1), # 256*7*7)self.layer5 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), # 256*7*7nn.MaxPool2d(kernel_size=3, stride=2), # 256*3*3nn.ReLU(inplace=True),)self.fc1 = nn.Linear(256*3*3, 1024)self.fc2 = nn.Linear(1024, 512)self.fc3 = nn.Linear(512, 10)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.layer5(x)x = x.view(-1, 256*3*3)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x

设置超参数

EPOCH = 5
BATCH_SIZE = 128
LR = 0.01

def validate(model, data):total = 0correct = 0for i, (images, labels) in enumerate(data):images = images.to(device)x = net(images)value, pred = torch.max(x,1)pred = pred.data.cpu()total += x.size(0)correct += torch.sum(pred == labels)return correct*100./total

初始化模型并将模型放到device上,如果有显卡就在cuda上,如果没有,那么在cpu

如果是纯cpu训练,速度十分感人

net = AlexNet().to(device)

# alexnet训练
def train():# 定义损失函数为交叉熵损失,优化方法为SGDcriterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)max_accuracy=0accuracies=[]for epoch in range(EPOCH):for i, (images,labels) in enumerate(trainloader):images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss_item = loss.item()loss.backward()optimizer.step()accuracy = float(validate(criterion, testloader))accuracies.append(accuracy)print("Epoch %d accuracy: %f loss: %f" % (epoch, accuracy, loss_item))if accuracy > max_accuracy:best_model = copy.deepcopy(criterion)max_accuracy = accuracyprint("Saving Best Model with Accuracy: ", accuracy)print('Epoch:', epoch+1, "Accuracy :", accuracy, '%')plt.plot(accuracies)return best_model

这一行代码是调用之前的train函数训练神经网络,初始化设置的epoch是5,大概也可以训练一个准确度较高的模型

alexnet = train()

为了防止断点或者bug导致jupyter重启之后重新训练模型,这一点经常遇到,本代码是在google的colab上训练的,为了保存训练的结果,还是将模型保存为pkl文件,这样本地就不用训练,直接调用训练之后的模型,之前尝试直接保存整个模型,但是会有莫名其妙的bug,暂时没有解决。这里尝试了另一种保存模型的方式[2],直接保存模型的参数,然后将参数传递到初始化的模型架构中,如下所示:

# 保存模型参数
torch.save(alexnet, '../models/alexnet.pkl')

# 加载模型
alexnet = AlexNet()
alexnet.load_state_dict(torch.load('../models/alexnet.pkl'))

AlexNet((layer1): Sequential((0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): ReLU(inplace=True))(layer2): Sequential((0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): ReLU(inplace=True))(layer3): Sequential((0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))(layer4): Sequential((0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))(layer5): Sequential((0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(2): ReLU(inplace=True))(fc1): Linear(in_features=2304, out_features=1024, bias=True)(fc2): Linear(in_features=1024, out_features=512, bias=True)(fc3): Linear(in_features=512, out_features=10, bias=True)
)

为直观的查看效果,选择一组测试集图片查看分类效果

plt.figure(figsize=(14, 14))
for i, (image, label) in enumerate(testloader):predict = torch.argmax(alexnet(image), axis=1)print((predict == label).sum()/label.shape[0])for j in range(image.shape[0]):plt.subplot(8, 8, j+1)plt.imshow(image[j, 0], cmap='gray')plt.title(predict[j].item())plt.axis('off')if i==1:break

tensor(1.)

请添加图片描述

参考文献



  • [1] Sowndharya206/alexnet


  • [2] SAVE AND LOAD THE MODEL

推荐阅读
  • pytorch Dropout过拟合的操作
    这篇文章主要介绍了pytorchDropout过拟合的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完 ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
  • 都会|可能会_###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了###haohaohao###图神经网络之神器——PyTorchGeometric上手&实战相关的知识,希望对你有一定的参考价值。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • Android开发实现的计时器功能示例
    本文分享了Android开发实现的计时器功能示例,包括效果图、布局和按钮的使用。通过使用Chronometer控件,可以实现计时器功能。该示例适用于Android平台,供开发者参考。 ... [详细]
  • Go GUIlxn/walk 学习3.菜单栏和工具栏的具体实现
    本文介绍了使用Go语言的GUI库lxn/walk实现菜单栏和工具栏的具体方法,包括消息窗口的产生、文件放置动作响应和提示框的应用。部分代码来自上一篇博客和lxn/walk官方示例。文章提供了学习GUI开发的实际案例和代码示例。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • Android源码深入理解JNI技术的概述和应用
    本文介绍了Android源码中的JNI技术,包括概述和应用。JNI是Java Native Interface的缩写,是一种技术,可以实现Java程序调用Native语言写的函数,以及Native程序调用Java层的函数。在Android平台上,JNI充当了连接Java世界和Native世界的桥梁。本文通过分析Android源码中的相关文件和位置,深入探讨了JNI技术在Android开发中的重要性和应用场景。 ... [详细]
  • navicat生成er图_实践案例丨ACL2020 KBQA 基于查询图生成回答多跳复杂问题
    摘要:目前复杂问题包括两种:含约束的问题和多跳关系问题。本文对ACL2020KBQA基于查询图生成的方法来回答多跳复杂问题这一论文工作进行了解读 ... [详细]
author-avatar
捕鱼达人2602917825
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有