热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch实例入门(1):图像分类

PyTorch的0.4版本带来了不小的变化,其中我最喜欢的是:Tensor和Variable这两个类合并了。原来nn的input是一个variable,现在可以直接用tensor。

PyTorch的0.4版本带来了不小的变化,其中我最喜欢的是:

  1. Tensor和Variable这两个类合并了。原来nn的input是一个variable,现在可以直接用tensor。这样在语法上更简洁易用,对初学者也更容易理解。
  2. Windows support。官方支持了windows,作为一个最近回归了windows的人很开心哈哈。

之前内存泄露的问题似乎也解决了,所以我又开心地从Tensorflow蹦回了PyTorch,顺便写点教程。先从最基本的开始,今天这篇文章讲怎么完成图像分类的任务。阅读前假设对神经网络和Python有一定了解。

通常我们在使用PyTorch的时候会用到两个包,一个是torch,一个是torchvision。其中torch是关于运算的包,torchvision则是打包了一些数据集,另外用torch实现了一些常见的神经网络模型,比如ResNet。

我们使用CIFAR-10作为数据集,包含了10个类别60000张图片,每张图片的大小为32×32,其中训练图片50000张,测试图片10000张。下图是一些示例

《PyTorch实例入门(1):图像分类》
《PyTorch实例入门(1):图像分类》 CIFAR-10示例

torchvision中已经打包好了这个数据集,我们不用自己下载,直接如下调用就可以了。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
# cifar-10官方提供的数据集是用numpy array存储的
# 下面这个transform会把numpy array变成torch tensor,然后把rgb值归一到[0, 1]这个区间
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 在构建数据集的时候指定transform,就会应用我们定义好的transform
# root是存储数据的文件夹,download=True指定如果数据不存在先下载数据
cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(root='./data', train=False,
transform=transform)

load完了之后我们看看这两个数据集的信息

print(cifar_train):
==>
Dataset CIFAR10
Number of datapoints: 50000
Split: train
Root Location: ./data
Transforms (if any): None
Target Transforms (if any): None
print(cifar_test)
==>
Dataset CIFAR10
Number of datapoints: 10000
Split: test
Root Location: ./data
Transforms (if any): None
Target Transforms (if any): None
# 数据其实是用numpy array存储的,label是个list
print(cifar_train.train_data.shape)
print(type(cifar_train.train_labels), len(cifar_train.train_labels))
==>
(50000, 32, 32, 3)
50000
print(cifar_test.test_data.shape)
print(type(cifar_test.test_labels), len(cifar_test.test_labels))
==>
(10000, 32, 32, 3)
10000

在训练的时候我们可以自己写代码手动遍历数据集,指定batch和遍历方法,不过PyTorch提供了一个DataLoader类来方便我们完成这些操作。

trainloader = torch.utils.data.DataLoader(cifar_train, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(cifar_test, batch_size=32, shuffle=True)

现在我们来定义卷积神经网络,为了简单起见,我们使用经典的LeNet,它包含两个卷积层和三个全连接层,网络结构如图

《PyTorch实例入门(1):图像分类》
《PyTorch实例入门(1):图像分类》

在PyTorch中定义神经网络非常简单,第一步先是继承nn.module这个类,然后定义如下两个函数,一般的网络这样操作就足够了

class LeNet(nn.Module):
# 一般在__init__中定义网络需要的操作算子,比如卷积、全连接算子等等
def __init__(self):
super(LeNet, self).__init__()
# Conv2d的第一个参数是输入的channel数量,第二个是输出的channel数量,第三个是kernel size
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
# 由于上一层有16个channel输出,每个feature map大小为5*5,所以全连接层的输入是16*5*5
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
# 最终有10类,所以最后一个全连接层输出数量是10
self.fc3 = nn.Linear(84, 10)
self.pool = nn.MaxPool2d(2, 2)
# forward这个函数定义了前向传播的运算,只需要像写普通的python算数运算那样就可以了
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
# 下面这步把二维特征图变为一维,这样全连接层才能处理
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

PyTorch具有自动求导功能,我们不需要自己写backward函数,所以很直观方便,写神经网络的结构就像写普通的数学运算公式一样。定义好网络之后我们就可以训练了,训练的代码也非常简单。首先,我们先构建一个网络实例。由于需要用到GPU,所以先获取device,然后再把网络的参数复制到GPU上

# 如果你没有GPU,那么可以忽略device相关的代码
device = torch.device("cuda:0")
net = LeNet().to(device)

然后我们需要定义Loss函数和优化方法,最简单的就是使用SGD了。PyTorch都预先定义好了这些东西

# optim中定义了各种各样的优化方法,包括SGD
import torch.optim as optim
# CrossEntropyLoss就是我们需要的损失函数
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

下面我们正式开始训练

print("Start Training...")
for epoch in range(30):
# 我们用一个变量来记录每100个batch的平均loss
loss100 = 0.0
# 我们的dataloader派上了用场
for i, data in enumerate(trainloader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device) # 注意需要复制到GPU
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
loss100 += loss.item()
if i % 100 == 99:
print('[Epoch %d, Batch %5d] loss: %.3f' %
(epoch + 1, i + 1, loss100 / 100))
loss100 = 0.0
print("Done Training!")

以上代码中,核心的代码是下面五行,我用注释解释了每一行的作用

# 首先要把梯度清零,不然PyTorch每次计算梯度会累加,不清零的话第二次算的梯度等于第一次加第二次的
optimizer.zero_grad()
# 计算前向传播的输出
outputs = net(inputs)
# 根据输出计算loss
loss = criterion(outputs, labels)
# 算完loss之后进行反向梯度传播,这个过程之后梯度会记录在变量中
loss.backward()
# 用计算的梯度去做优化
optimizer.step()

ok,训练完了之后我们来检测一下准确率,我们用训练好的模型来预测test数据集

# 构造测试的dataloader
dataiter = iter(testloader)
# 预测正确的数量和总数量
correct = 0
total = 0
# 使用torch.no_grad的话在前向传播中不记录梯度,节省内存
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
# 预测
outputs = net(images)
# 我们的网络输出的实际上是个概率分布,去最大概率的哪一项作为预测分类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))

在我这里最终训练了30个epoch之后准确率为63%。

以上就是用PyTroch做图像分类的基本步骤,是不是很简单呢?


推荐阅读
  • 深度强化学习Policy Gradient基本实现
    全文共2543个字,2张图,预计阅读时间15分钟。基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然 ... [详细]
  • 每日一书丨AI圣经《深度学习》作者斩获2018年图灵奖
    2019年3月27日——ACM宣布,深度学习之父YoshuaBengio,YannLeCun,以及GeoffreyHinton获得了2018年的图灵奖, ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • EIGRP增强内部网关路由协议协议号88IGRPEIGRP都是CISCO的私有协议.---高级距离矢量协议1、是唯一的一种LSDV的混合协议2、EIGRP拥有目前最快的网络路由收敛 ... [详细]
  • lora物联网开发教程(物联网lora特点)
    长距离星型架构,由于长距离连接性,从而减少了电池寿命。这个协议采用了阿罗哈法。在一个网状网络或者一个异步网络中,例如蜂窝网,结点必须频繁的被唤醒,来同步网络和检查消息。这种同步,大 ... [详细]
  • 在这一期的SendMessage函数应用中,我将向大家介绍如何利用消息函数来扩展树型列表(TreeView)控件的功能相信对于树型列表控件大家十分的熟悉, ... [详细]
  • 2022年Python面试题一.Python基础二.企业面试题结束语🥇🥇🥇✅作者简介:大家好我是编程IDὌ ... [详细]
  • Ithinkthishasbeenupbefore,butcouldntfindanyanswertoit.Ifitsalreadyansweredplease ... [详细]
  • 文本生成图像简要回顾 text to image synthesis
    摘要       文本生成图像作为近几年的热门研究领域,其解决的问题是从一句描述性文本生成与之对应的图片。近一周来,我通过阅读了近几年发表于顶会的近10篇论文,做出本文中对该方向的 ... [详细]
  • 图像处理(7) : 边缘检测
    边缘检测是图形图像处理、计算机视觉和机器视觉中的一个基本工具,通常用于特征提取和特征检测,旨在检测一张数字图像中有明显变化的边缘或者不连续的区域 ... [详细]
  • 让日期区间更友好!把常见的日期格式如:YYYY-MM-DD转换成一种更易读的格式。易读格式应该是用月份名称代替月份数字,用序数词代替数字来 ... [详细]
  • AI 学习路线:从Python开始机器学习
    AI 学习路线:从Python开始机器学习 ... [详细]
  • PNG在IE6下透明问题的解决办法
    2019独角兽企业重金招聘Python工程师标准做Web开发的朋友一定都知道PNG是一个相当不错的图片格式,但是这个好的格式却在IE6时代造成了麻烦࿰ ... [详细]
  • 编者按:来自自江民科技的消息称,该公司创始人王江民近日因病去世,享年59岁,为了纪念这位中国反病毒事业的知名专家与老前辈,现摘录来自刘韧在知识英雄系列中采访其的一篇文章 王江民,著名的反病毒专家 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
author-avatar
asdfu_814
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有