热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

python图像分类_使用PyTorch训练一个图像分类器实例

如下所示:?123456789importtorchimporttorchvisionimporttorchvision.transform

如下所示:

?

1

2

3

4

5

6

7

8

9

import torch

import torchvision

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import numpy as np

print("torch: %s" % torch.__version__)

print("tortorchvisionch: %s" % torchvision.__version__)

print("numpy: %s" % np.__version__)

Out:

?

1

2

3

torch:1.0.0

tortorchvisionch:0.2.1

numpy:1.15.4

数据从哪儿来?

通常来说,你可以通过一些python包来把图像、文本、音频和视频数据加载为numpy array。然后将其转换为torch.*Tensor。

图像。Pillow、OpenCV是用得比较多的

音频。scipy和librosa

文本。纯Python或者Cython就可以完成数据加载,可以在NLTK和SpaCy找到数据

对于计算机视觉而言,我们有torchvision包,它可以用来加载一下常用数据集如Imagenet、CIFAR10、MINIST等等,也有一些常用的为图像准备数据转换例如torchvision.datasets和torch.utils.data.DataLoader。

这次的教程中,我们使用CIFAR10数据集,他有‘airplane', ‘automobile', ‘bird', ‘cat', ‘deer', ‘dog', ‘frog', ‘horse', ‘ship', ‘truck'这几个类别的图像。图像大小都是3x32x32的。也就是说,图像都是三通道的,每一张图的尺寸都是32x32。

1-2005011123243R.jpg

训练一个图像分类器

步骤如下:

使用torchvision加载、归一化训练集和测试集

定义卷积神经网络

定义损失函数

使用训练集训练网络

使用测试集测试网络

1. 加载、归一化CIFAR10

我们可以使用torchvision很轻松的完成

torchvision的数据集是基于PILImage的,数值是[0, 1],我们需要将其转成范围为[-1, 1]的Tensor

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

transform= transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))

])

trainset= torchvision.datasets.CIFAR10(root='./data', train=True,

download=True, transform=transform)

trainloader= torch.utils.data.DataLoader(trainset, batch_size=4,

shuffle=True, num_workers=4)

testset= torchvision.datasets.CIFAR10(root='./data', train=False,

download=True, transform=transform)

testloader= torch.utils.data.DataLoader(testset, batch_size=4,

shuffle=True, num_workers=4)

classes= ('plane','car','bird','cat',

'deer','dog','frog','horse','ship','truck')

Out:

?

1

2

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

Files already downloadedand verified

让我们来看看训练集的图片

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

# 显示一张图片

def imshow(img):

img= img/ 2 + 0.5 # 逆归一化

npimg= img.numpy()

plt.imshow(np.transpose(npimg, (1,2,0)))

plt.show()

# 任意地拿到一些图片

dataiter= iter(trainloader)

images, labels= dataiter.next()

# 显示图片

imshow(torchvision.utils.make_grid(images))

# 显示类标

print(' '.join('%5s' % classes[labels[j]]for jin range(4)))

Out:

1-200501112333307.jpg

?

1

truck dog ship dog

2. 定义卷积神经网络

可以直接复制神经网络的代码,修改里面的几层即可。

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

def __init__(self):

super(Net,self).__init__()

self.conv1= nn.Conv2d(3,6,5)

self.pool= nn.MaxPool2d(2,2)

self.conv2= nn.Conv2d(6,16,5)

self.fc1= nn.Linear(16 * 5 * 5,120)

self.fc2= nn.Linear(120,84)

self.fc3= nn.Linear(84,10)

def forward(self, x):

x= self.pool(F.relu(self.conv1(x)))

x= self.pool(F.relu(self.conv2(x)))

x= x.view(-1,16 * 5 * 5)

x= F.relu(self.fc1(x))

x= F.relu(self.fc2(x))

x= self.fc3(x)

return x

net= Net()

3. 定义损失函数和优化器

使用多分类交叉熵损失函数,和带有momentum的SGD作为优化器

?

1

2

3

4

import torch.optim as optim

criterion= nn.CrossEntropyLoss()

optimizer= optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

4. 训练网络

我们直接使用循环语句遍历数据集即可完成训练

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

nums_epoch= 2

for epochin range(nums_epoch):

_loss= 0.0

for i, (inputs, labels)in enumerate(trainloader,0):

inputs, labels= inputs.to(device), labels.to(device)

optimizer.zero_grad()

outputs= net(inputs)

loss= criterion(outputs, labels)

loss.backward()

optimizer.step()

_loss+= loss.item()

if i% 2000 == 1999:# 每2000步打印一次损失值

print('[%d, %5d] loss: %.3f' %

(epoch+ 1, i+ 1, _loss/ 2000))

_loss= 0.0

print('Finished Training')

Out:

?

1

2

3

4

5

6

7

8

9

10

11

12

13

[1,2000] loss:1.178

[1,4000] loss:1.200

[1,6000] loss:1.168

[1,8000] loss:1.175

[1,10000] loss:1.185

[1,12000] loss:1.165

[2,2000] loss:1.073

[2,4000] loss:1.066

[2,6000] loss:1.100

[2,8000] loss:1.107

[2,10000] loss:1.083

[2,12000] loss:1.103

Finished Training

5. 测试网络

这个网络已经训练了两个epoch,我们现在来看看这个网络是不是学到了一些什么东西。

我们让这个神经网络预测几张图片,看看它的答案与真实答案的差别。

下面我们选取一些测试数据集中的数据,看看他们的真实标签。

?

1

2

3

4

5

# 展示测试数据集

dataiter= iter(testloader)

images, labels= dataiter.next()

imshow(torchvision.utils.make_grid(images))

print('GraoundTruth: ',' '.join(['%5s' % classes[labels[j]]for jin range(4)]))

Out:

1-200501112344Y3.jpg

?

1

GraoundTruth: ship ship deer ship

接着我们让神经网络来给出预测标签

神经网络的输出是10个信号值,信号值最高的那个神经元表示整个网络的预测值,所以我们需要拿到信号最强的那个节点的索引值

?

1

2

3

4

# 展示预测值

outputs= net(images)

_, predicted= torch.max(outputs,1)

print('Predicted: ',' '.join(['%5s' % classes[predicted[j]]for jin range(4)]))

Out:

?

1

Predicted: car ship horse ship

下面我们对整个测试集做一次评估:

?

1

2

3

4

5

6

7

8

9

10

11

# 评估测试数据集

correct, total= 0,0

with torch.no_grad():

for images, labelsin testloader:

outputs= net(images)

_, predicted= torch.max(outputs,1)

total+= labels.size(0)

correct+= (labels== predicted).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (

100 * correct/ total))

Out:

?

1

Accuracy of the network on the10000 test images:58 %

整个结果比随机猜要好得多(随机猜是10%的概率)。看来我们的神经网络还是学到了点东西。

下面我们来看看它在哪一个类别的分类上做得最好:

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

# 按类标评估

n_classes= len(classes)

class_correct, class_total= [0]*n_classes, [0]* n_classes

with torch.no_grad():

for images, labelsin testloader:

outputs= net(images)

_, predicted= torch.max(outputs,1)

is_correct= (labels== predicted).squeeze()

for iin range(len(labels)):

label= labels[i]

class_total[label]+= 1

class_correct[label]+= is_correct[i].item()

for iin range(n_classes):

print('Accuracy of %5s: %.2f %%' % (

classes[i],100.0 * class_correct[i]/ class_total[i]

))

Out:

?

1

2

3

4

5

6

7

8

9

10

Accuracy of plane:67.00 %

Accuracy of car:71.50 %

Accuracy of bird:55.20 %

Accuracy of cat:45.60 %

Accuracy of deer:38.20 %

Accuracy of dog:47.00 %

Accuracy of frog:78.80 %

Accuracy of horse:55.90 %

Accuracy of ship:72.70 %

Accuracy of truck:57.50 %

在GPU上训练

就像把Tensor从CPU转移到GPU一样,神经网络也可以转移到GPU上

首先需要检查是否有可用的GPU

?

1

2

3

4

device= torch.device("cuda:0" if torch.cuda.is_available()else "cpu")

# 假设我们在支持CUDA的机器上,我们可以打印出CUDA设备:

print(device)

Out:

?

1

cuda:0

我们假设device已经是CUDA设备了

下面命令将递归的将所有模块和参数、缓存转移到CUDA设备上去

?

1

net.to(device)

Out:

?

1

2

3

4

5

6

7

8

Net(

(conv1): Conv2d(3,6, kernel_size=(5,5), stride=(1,1))

(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

(conv2): Conv2d(6,16, kernel_size=(5,5), stride=(1,1))

(fc1): Linear(in_features=400, out_features=120, bias=True)

(fc2): Linear(in_features=120, out_features=84, bias=True)

(fc3): Linear(in_features=84, out_features=10, bias=True)

)

注意,在训练过程中的传入输入数据时,也需要转移到GPU上

并且,需要重新实例化优化器,否则会报错

?

1

inputs, labels= inputs.to(device), labels.to(device)

练习:尝试增加神经网络的宽度。第一个nn.Conv2d的第二个参数和第二个nn.Conv2d的第一个参数的值必须一样。看看会有什么样的效果。

以上这篇使用PyTorch训练一个图像分类器实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/TinyJian/article/details/86617064



推荐阅读
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • EzPP 0.2发布,新增YAML布局渲染功能
    EzPP发布了0.2.1版本,新增了YAML布局渲染功能,可以将YAML文件渲染为图片,并且可以复用YAML作为模版,通过传递不同参数生成不同的图片。这个功能可以用于绘制Logo、封面或其他图片,让用户不需要安装或卸载Photoshop。文章还提供了一个入门例子,介绍了使用ezpp的基本渲染方法,以及如何使用canvas、text类元素、自定义字体等。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • 解决python matplotlib画水平直线的问题
    本文介绍了在使用python的matplotlib库画水平直线时可能遇到的问题,并提供了解决方法。通过导入numpy和matplotlib.pyplot模块,设置绘图对象的宽度和高度,以及使用plot函数绘制水平直线,可以解决该问题。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • 本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。 ... [详细]
author-avatar
木_妍_595
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有