热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch模型加载测试_「pytorch速成」Pytorch图像分类从模型自定义到测试

原标题:「pytorch速成」Pytorch图像分类从模型自定义到测试前面已跟大家介绍了Caffe和TensorFlow,今天说说Pytorch。1什么

原标题:「pytorch速成」Pytorch图像分类从模型自定义到测试

前面已跟大家介绍了Caffe和TensorFlow,今天说说Pytorch。

1 什么是 Pytorch

一句话总结 Pytorch = Python + Torch。

Torch 是纽约大学的一个机器学习开源框架,几年前在学术界非常流行,包括 Lecun等大佬都在使用。但是由于使用的是一种绝大部分人绝对没有听过的 Lua 语言,导致很多人都被吓退。后来随着 Python 的生态越来越完善,Facebook 人工智能研究院推出了Pytorch并开源。Pytorch不是简单的封装 Torch并提供Python接口,而是对Tensor以上的所有代码进行了重构,同TensorFlow一样,增加了自动求导。

后来Caffe2全部并入Pytorch,如今已经成为了非常流行的框架。很多最新的研究如风格化、GAN 等大多数采用Pytorch源码,这也是我们必须要讲解它的原因。

1.1 特点

(1) 动态图计算。TensorFlow从静态图发展到了动态图机制Eager Execution,pytorch则一开始就是动态图机制。动态图机制的好处就是随时随地修改,随处debug,没有类似编译的过程。

(2) 简单。相比TensorFlow中Tensor、Variable、Session等概念充斥,数据读取接口频繁更新,tf.nn、tf.layers、tf.contrib各自重复,Pytorch则是从Tensor到Variable再到nn.Module,最新的Pytorch已经将Tensor和Variable合并,这分别就是从数据张量到网络的抽象层次的递进。有人调侃TensorFlow的设计是“make it complicated”,那么 Pytorch的设计就是“keep it simple”。

1.2 重要概念

(1) Tensor/Variable

每一个框架都有基本的数据结构,Caffe是blob,TensorFlow和Pytorch都是Tensor,都是高维数组。Pytorch中的Tensor使用与Numpy的数组非常相似,两者可以互转且共享内存。

tensor包括cpu和gpu两种类型,如torch.FloatTensortorch.cuda.FloatTensorvirable,就分别表示cpu和gpu下的32位浮点数。

tensor包含一些属性。data,即Tensor内容;Grad,是与data对应的梯度;requires_grad,是否容许进行反向传播的学习,更多的可以去查看API。

(2) nn.module

抽象好的网络数据结构,可以表示为网络的一层,也可以表示为一个网络结构,这是一个基类。在实际使用过程中,经常会定义自己的网络,并继承nn.Module。具体的使用,我们看下面的网络定义吧。

(3) torchvision包,包含了目前流行的数据集,模型结构和常用的图片转换工具

2 Pytorch 训练

安装咱们就不说了,接下来的任务就是开始训练模型。训练模型包括数据准备、模型定义、结果保存与分析。

2.1 数据读取

前面已经介绍了Caffe和TensorFlow的数据读取,两者的输入都是图片list,但是读取操作过程差异非常大,Pytorch与这两个又有很大的差异。这一次,直接利用文件夹作为输入,这是 Pytorch更加方便的做法。数据读取的完整代码如下:

data_dir = '../../../../datas/head/'

data_transforms = {

'train': transforms.Compose([

transforms.RandomSizedCrop(48),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])

]),

'val': transforms.Compose([

transforms.Scale(64),

transforms.CenterCrop(48),

transforms.ToTensor(),

transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])

]),

}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),

data_transforms[x]) for x in ['train', 'val']}

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],

batch_size=16,

shuffle=True,

num_workers=4) for x in ['train', 'val']}

下面一个一个解释,完整代码请移步 Git 工程。

(1) datasets.ImageFolder

Pytorch的torchvision模块中提供了一个dataset 包,它包含了一些基本的数据集如mnist、coco、imagenet和一个通用的数据加载器ImageFolder。

它会以这样的形式组织数据,具体的请到Git工程中查看。

root/left/1.png

root/left/2.png

root/left/3.png

root/right/1.png

root/right/2.png

root/right/3.png

imagefolder有3个成员变量。

self.classes:用一个list保存类名,就是文件夹的名字。

self.class_to_idx:类名对应的索引,可以理解为 0、1、2、3 等。

self.imgs:保存(imgpath,class),是图片和类别的数组。

不同文件夹下的图,会被当作不同的类,天生就用于图像分类任务。

(2) Transforms

这一点跟Caffe非常类似,就是定义了一系列数据集的预处理和增强操作。到此,数据接口就定义完毕了,接下来在训练代码中看如何使用迭代器进行数据读取就可以了,包括 scale、减均值等。

(3) torch.utils.data.DataLoader

这就是创建了一个 batch,生成真正网络的输入。关于更多 Pytorch 的数据读取方法,请自行学习。

2.2 模型定义

import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np

class simpleconv3(nn.Module):`

def __init__(self):

super(simpleconv3,self).__init__()

self.conv1 = nn.Conv2d(3, 12, 3, 2)

self.bn1 = nn.BatchNorm2d(12)

self.conv2 = nn.Conv2d(12, 24, 3, 2)

self.bn2 = nn.BatchNorm2d(24)

self.conv3 = nn.Conv2d(24, 48, 3, 2)

self.bn3 = nn.BatchNorm2d(48)

self.fc1 = nn.Linear(48 * 5 * 5 , 1200)

self.fc2 = nn.Linear(1200 , 128)

self.fc3 = nn.Linear(128 , 2)

def forward(self , x):

x = F.relu(self.bn1(self.conv1(x)))

x = F.relu(self.bn1(self.conv2(x)))

x = F.relu(self.bn3(self.conv3(x)))

x = x.view(-1 , 48 * 5 * 5)

x = F.relu(self.fc1(x))

x = F.relu(self.fc2(x))

x = self.fc3(x)

return x

我们的例子都是采用一个简单的3层卷积 + 2层全连接层的网络结构。根据上面的网络结构的定义,需要做以下事情。

(1) simpleconv3(nn.Module)

继承nn.Module,前面已经说过,Pytorch的网络层是包含在nn.Module 里,所以所有的网络定义,都需要继承该网络层,并实现super方法,如下:

super(simpleconv3,self).__init__()

这个就当作一个标准执行就可以了。

(2) 网络结构的定义都在nn包里,举例说明:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

完整的接口如上,定义的第一个卷积层如下:

nn.Conv2d(3, 12, 3, 2)

即输入通道为3,输出通道为12,卷积核大小为3,stride=2,其他的层就不一一介绍了,大家可以自己去看nn的API。

(3) forward

backward方法不需要自己实现,但是forward函数是必须要自己实现的,从上面可以看出,forward 函数也是非常简单,串接各个网络层就可以了。

对比Caffe和TensorFlow可以看出,Pytorch的网络定义更加简单,初始化方法都没有显示出现,因为 Pytorch已经提供了默认初始化。

如果我们想实现自己的初始化,可以这么做:

init.xavier_uniform(self.conv1.weight)init.constant(self.conv1.bias, 0.1)

它会对conv1的权重和偏置进行初始化。如果要对所有conv层使用 xavier 初始化呢?可以定义一个函数:

def weights_init(m):

if isinstance(m, nn.Conv2d):

xavier(m.weight.data)

xavier(m.bias.data)

net = Net()

net.apply(weights_init)

3 模型训练

网络定义和数据加载都定义好之后,就可以进行训练了,老规矩先上代码:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):

for epoch in range(num_epochs):

print('Epoch {}/{}'.format(epoch, num_epochs - 1))

for phase in ['train', 'val']:

if phase == 'train':

scheduler.step()

model.train(True)

else:

model.train(False)

running_loss = 0.0 running_corrects = 0.0

for data in dataloders[phase]:

inputs, labels = data

if use_gpu:

inputs = Variable(inputs.cuda())

labels = Variable(labels.cuda())

else:

inputs, labels = Variable(inputs), Variable(labels)

optimizer.zero_grad()

outputs = model(inputs)

_, preds = torch.max(outputs.data, 1)

loss = criterion(outputs, labels)

if phase == 'train':

loss.backward()

optimizer.step()

running_loss += loss.data.item()

running_corrects += torch.sum(preds == labels).item()

epoch_loss = running_loss / dataset_sizes[phase]

epoch_acc = running_corrects / dataset_sizes[phase]

if phase == 'train':

writer.add_scalar('data/trainloss', epoch_loss, epoch)

writer.add_scalar('data/trainacc', epoch_acc, epoch)

else:

writer.add_scalar('data/valloss', epoch_loss, epoch)

writer.add_scalar('data/valacc', epoch_acc, epoch)

print('{} Loss: {:.4f} Acc: {:.4f}'.format(

phase, epoch_loss, epoch_acc))

writer.export_scalars_to_json("./all_scalars.json")

writer.close()

return model

分析一下上面的代码,外层循环是epoches,然后利用 for data in dataloders[phase] 循环取一个epoch 的数据,并塞入variable,送入model。需要注意的是,每一次forward要将梯度清零,即optimizer.zero_grad(),因为梯度会记录前一次的状态,然后计算loss进行反向传播。

loss.backward()

optimizer.step()

下面可以分别得到预测结果和loss,每一次epoch 完成计算。

epoch_loss = running_loss / dataset_sizes[phase]

epoch_acc = running_corrects / dataset_sizes[phase]

_, preds = torch.max(outputs.data, 1)

loss = criterion(outputs, labels)

可视化是非常重要的,鉴于TensorFlow的可视化非常方便,我们选择了一个开源工具包,tensorboardx,安装方法为pip install tensorboardx,使用非常简单。

第一步,引入包定义创建:

from tensorboardX import SummaryWriter

writer = SummaryWriter()

第二步,记录变量,如train阶段的 loss,writer.add_scalar('data/trainloss', epoch_loss, epoch)。

按照以上操作就完成了,完整代码可以看配套的Git 项目,我们看看训练中的记录。Loss和acc的曲线图如下:

网络的收敛没有Caffe和TensorFlow好,大家可以自己去调试调试参数了,随便折腾吧。

4 Pytorch 测试

上面已经训练好了模型,接下来的目标就是要用它来做inference了,同样给出代码。

import torch

import torch.nn as nn

import torch.optim as optim

from torch.optim import lr_scheduler

from torch.autograd import Variable

import torchvision

from torchvision import datasets, models, transforms

import time

import os

from PIL import Image

import sys

import torch.nn.functional as F

from net import simpleconv3

data_transforms = transforms.Compose([

transforms.Resize(48),

transforms.ToTensor(),

transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

net = simpleconv3()

modelpath = sys.argv[1]

net.load_state_dict(torch.load(modelpath,map_location=lambda storage,loc: storage))

imagepath = sys.argv[2]

image = Image.open(imagepath)

imgblob = data_transforms(image).unsqueeze(0)

imgblob = Variable(imgblob)

torch.no_grad()

predict = F.softmax(net(imgblob))

print(predict)

从上面的代码可知,做了几件事:

定义网络并使用torch.load和load_state_dict载入模型。

用PIL的Image包读取图片,这里没有用OpenCV,因为Pytorch默认的图片读取工具就是PIL的Image,它会将图片按照RGB的格式,归一化到 0~1 之间。读取图片之后,必须转化为Tensor变量。

evaluation的时候,必须设置torch.no_grad(),然后就可以调用 softmax 函数得到结果了。

5 总结

本节讲了如何用 Pytorch 完成一个分类任务,并学习了可视化以及使用训练好的模型做测试。

配套资料在github,https://github.com/longpeng2008/yousan.ai。

责任编辑:



推荐阅读
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 本文详细介绍了 Java 网站开发的相关资源和步骤,包括常用网站、开发环境和框架选择。 ... [详细]
  • 本文介绍了如何使用 Google Colab 的免费 GPU 资源进行深度学习应用开发。Google Colab 是一个无需配置即可使用的云端 Jupyter 笔记本环境,支持多种深度学习框架,并且提供免费的 GPU 计算资源。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 2020年9月15日,Oracle正式发布了最新的JDK 15版本。本次更新带来了许多新特性,包括隐藏类、EdDSA签名算法、模式匹配、记录类、封闭类和文本块等。 ... [详细]
  • 在Conda环境中高效配置并安装PyTorch和TensorFlow GPU版的方法如下:首先,创建一个新的Conda环境以避免与基础环境发生冲突,例如使用 `conda create -n pytorch_gpu python=3.7` 命令。接着,激活该环境,确保所有依赖项都正确安装。此外,建议在安装过程中指定CUDA版本,以确保与GPU兼容性。通过这些步骤,可以确保PyTorch和TensorFlow GPU版的顺利安装和运行。 ... [详细]
  • 能够感知你情绪状态的智能机器人即将问世 | 科技前沿观察
    本周科技前沿报道了多项重要进展,包括美国多所高校在机器人技术和自动驾驶领域的最新研究成果,以及硅谷大型企业在智能硬件和深度学习技术上的突破性进展。特别值得一提的是,一款能够感知用户情绪状态的智能机器人即将问世,为未来的人机交互带来了全新的可能性。 ... [详细]
  • 利用Python与Android进行高效移动应用开发
    通过结合Python和Android,可以实现高效的移动应用开发。首先,需要安装Scripting Layer for Android (SL4A),这是一个开源项目,旨在为Android系统提供脚本语言支持。SL4A不仅简化了开发流程,还允许开发者使用Python等高级语言编写脚本,从而提高开发效率和代码可维护性。此外,SL4A还支持多种其他脚本语言,进一步扩展了其应用范围。通过这种方式,开发者可以快速构建功能丰富的移动应用,同时保持较高的灵活性和可扩展性。 ... [详细]
  • 在Windows环境下离线安装PyTorch GPU版时,首先需确认系统配置,例如本文作者使用的是Win8、CUDA 8.0和Python 3.6.5。用户应根据自身Python和CUDA版本,在PyTorch官网查找并下载相应的.whl文件。此外,建议检查系统环境变量设置,确保CUDA路径正确配置,以避免安装过程中可能出现的兼容性问题。 ... [详细]
  • 利用PaddleSharp模块在C#中实现图像文字识别功能测试
    PaddleSharp 是 PaddleInferenceCAPI 的 C# 封装库,适用于 Windows (x64)、NVIDIA GPU 和 Linux (Ubuntu 20.04) 等平台。本文详细介绍了如何使用 PaddleSharp 在 C# 环境中实现图像文字识别功能,并进行了全面的功能测试,验证了其在多种硬件配置下的稳定性和准确性。 ... [详细]
  • 2019年斯坦福大学CS224n课程笔记:深度学习在自然语言处理中的应用——Word2Vec与GloVe模型解析
    本文详细解析了2019年斯坦福大学CS224n课程中关于深度学习在自然语言处理(NLP)领域的应用,重点探讨了Word2Vec和GloVe两种词嵌入模型的原理与实现方法。通过具体案例分析,深入阐述了这两种模型在提升NLP任务性能方面的优势与应用场景。 ... [详细]
  • 为何Serverless将成为未来十年的主导技术领域?
    为何Serverless将成为未来十年的主导技术领域? ... [详细]
  • 使用Tkinter构建51Ape无损音乐爬虫UI
    本文介绍了如何使用Python的内置模块Tkinter来构建一个简单的用户界面,用于爬取51Ape网站上的无损音乐百度云链接。虽然Tkinter入门相对简单,但在实际开发过程中由于文档不足可能会带来一些不便。 ... [详细]
  • 使用HTML和JavaScript实现视频截图功能
    本文介绍了如何利用HTML和JavaScript实现从远程MP4、本地摄像头及本地上传的MP4文件中截取视频帧,并展示了具体的实现步骤和示例代码。 ... [详细]
author-avatar
南北风味街
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有