热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch构建与部署深度学习模型(1)简单应用

PyTorch是最受欢迎的深度学习Python库之一,它被人工智能研究社区广泛使用。许多开发者和研究人员使用PyTorch来加速深度学习研究实验和原型设计。1.为什么使用PyTor



PyTorch是最受欢迎的深度学习Python库之一,它被人工智能研究社区广泛使用。许多开发者和研究人员使用PyTorch来加速深度学习研究实验和原型设计。


1.为什么使用PyTorch

如果你正在学习机器学习,进行深度学习研究,或构建人工智能系统,你可能需要使用深度学习框架。深度学习框架可以很容易地完成数据加载、预处理、模型设计、训练和部署等常见任务。PyTorch由于其简单、灵活和Python接口,已经在学术和研究团体中非常受欢迎。
以下是学习和使用PyTorch的一些原因:


  • PyTorch很受欢迎
  • PyTorch得到所有主流云平台的支持,如Amazon Web Services (AWS)、谷歌云平台(GCP)、微软Azure、阿里云等
  • PyTorch得到Google Colaboratory和Kaggle Kernels支持
  • PyTorch成熟稳定
  • PyTorch支持CPU、GPU、TPU和并行处理
  • PyTorch支持分布式训练:您可以在多台机器上的多个gpu上训练神经网络。
  • PyTorch支持部署到生产环境:使用新的TorchScript和TorchServe特性,您可以轻松地将模型部署到包括云服务器在内的生产环境中。
  • PyTorch开始支持移动端部署:虽然目前还处于试验阶段,但你现在可以将模型部署到iOS和Android设备上。
  • PyTorch拥有一个庞大的生态系统和一组开源库:Torchvision、fastai和PyTorch Lightning等库扩展了功能,并支持自然语言处理(NLP)和计算机视觉等特定领域。
  • PyTorch也有一个c++前端:虽然在本书中我将重点关注Python接口,PyTorch还支持前端c++接口。如果您需要构建高性能、低延迟应用程序,您可以使用与Python API相同的设计和架构,用c++编写它们。
  • PyTorch本身支持ONNX格式:你可以很容易地将你的模型导出为ONNX格式,并在与ONNX兼容的平台、运行时或可视化工具中使用它们。
  • PyTorch拥有一个庞大的开发者社区和用户论坛:https://pytorch.tips/discuss

2.小试牛刀

在实践中,您将在代码的开头导入所有必要的库。但是,在本例中,我们将在使用库时导入它们,这样您就可以看到每个任务需要哪些库。

首先,让我们选择一个我们想分类的图像。在这个例子中,我们会选择一杯新鲜、热的咖啡。使用以下代码下载咖啡映像到您的本地环境:

import urllib.request
url = 'https://upload.wikimedia.org/wikipedia/commons/4/45/A_small_cup_of_coffee.JPG'
fpath = 'coffee.jpg'
urllib.request.urlretrieve(url, fpath)

注意,代码使用了urllib库的urlretrieve()函数从web获取图像。通过指定fpath,将文件重命名为coffee.jpg

接下来,我们使用PIL读取我们的本地图像:

import matplotlib.pyplot as plt
from PIL import Image
img = Image.open('coffee.jpg')
plt.imshow(img)

在这里插入图片描述
注意,我们还没有使用PyTorch。接下来,我们将把图像传递给一个经过预处理的图像分类神经网络(NN),但在此之前,我们需要对图像进行预处理。预处理数据在机器学习中很常见,因为神经网络期望输入满足一定的要求。

在我们的示例中,图像数据是RGB 1600 × 1200像素JPEG格式。我们需要应用一系列被称为transforms的预处理步骤,将图像转换成适合NN的格式。我们在以下代码中使用Torchvision实现了这一点:

import torch
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
img_tensor = transform(img)
print(type(img_tensor), img_tensor.shape)
# torch.Size([3, 224, 224])

我们使用Compose()变换来定义一系列对图像进行预处理的变换。首先,我们需要调整和裁剪图像以适应NN。图像目前是PIL格式,因为这是我们之前读取它的方式。但是我们的神经网络需要一个张量输入,所以我们把PIL图像转换成一个张量。

张量是PyTorch中最基本的数据对象,我们将在下一章中详细介绍它们。你可以考虑像NumPy数组或数字数组这样的张量,它们有一些额外的特性。现在,我们只需要把图像转换成一个数字张量数组就可以了。将像素值的范围缩放到0到1之间。

我们再应用一个称为Normalize()的变换,将[0,1]范围内的像素值进行标准化。平均值和标准差(std)的值是根据用来训练模型的数据预先计算的。对图像进行归一化可以提高分类器的精度。

最后,我们调用transform(img)对图像应用所有的变换。正如你所看到的,img_tensor是一个3 × 224 × 224torch.Tensor,代表224 × 224像素的3通道图像的张量。

高效的机器学习是批量处理数据的,我们的模型期望得到一批数据。然而,我们只有一个图像,所以我们需要创建一个大小为1的批处理,如下所示的代码:

batch = img_tensor.unsqueeze(0)
print(batch.shape)
# out: torch.Size([1, 3, 224, 224])

我们使用PyTorch的unsqueeze()函数向张量添加一个维度,并创建一个大小为1的批。现在我们有一个大小为1 × 3 × 224 × 224的张量。

现在我们的图像已经为分类器NN准备好了!我们将使用一个名为AlexNet的著名图像分类器。AlexNet赢得了2012ImageNet大型视觉识别挑战赛。使用Torchvision加载这个模型很容易,如下所示的代码:

from torchvision import models
model = models.alexnet(pretrained=True)

我们将使用一个预先训练过的模型,所以我们不需要训练它。AlexNet模型已经用数百万张图像进行了预训练,在对图像进行分类方面做得很好。让我们传入我们的图像,看看它是如何做到的:

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# out(results will vary): cpu
model.eval()
model.to(device)
y = model(batch.to(device))
print(y.shape)
# out: torch.Size([1, 1000])

GPU加速是PyTorch的一个关键优势。在第一行中,我们使用PyTorch的cuda.is_available()函数来查看我们的机器是否有GPU。我们只对一张图像进行分类,所以我们不需要GPU,但如果我们有大量的图像,那么GPU可能会帮助我们加快速度。

model.eval()函数配置我们的AlexNet模型进行推断或预测(而不是训练)。模型的某些组成部分只在训练中使用,我们不想在这里使用它们。使用model.to(device)batch.to(device)将我们的模型和输入数据发送给GPU(如果可用的话),并执行model(batch.to(device))运行我们的分类器。

输出y1,000个输出组成。因为我们的批处理只包含一个图像,第一个维度是1,而类的数量是1000,每个类一个值。值越高,图像包含该类的可能性越大。下面的代码找到获胜的类:

y_max, index = torch.max(y,1)
print(index, y_max)
# out: tensor([967]) tensor([22.3059],
# grad_fn=)

使用PyTorchmax()函数,我们看到index967的值最大,为22.3059,因此是优胜者。然而,我们不知道967类代表什么。让我们用类名加载这个文件,并找出答案:

import urllib.request
url = "https://raw.githubusercontent.com/joe-papa/pytorch-book/main/files/imagenet_class_labels.txt"
fpath = 'imagenet_class_labels.txt'
urllib.request.urlretrieve(url, fpath)
with open('imagenet_class_labels.txt') as f:
classes = [line.strip() for line in f.readlines()]
print(classes[967])
# out: 967: 'espresso',

与前面一样,我们使用urlretrieve()下载包含每个类描述的文本文件。然后,我们使用readlines()读取文件,并创建一个包含类名的列表。当我们print(classes[967])时,它告诉我们967类是浓缩咖啡!

使用PyTorchsoftmax()函数,我们可以将输出值转换为概率:

prob = torch.nn.functional.softmax(y, dim=1)[0] * 100
print(classes[index[0]], prob[index[0]].item())
#967: 'espresso', 87.85208892822266

要打印索引处的概率,我们使用PyTorchtensor.item()方法。item()方法经常被使用,它返回一个张量中包含的数值。实验结果表明,该模型对该图像的确定度为87.85%。

我们可以使用PyTorchsort()函数对输出概率进行排序,看看前5个:

_, indices = torch.sort(y, descending=True)
for idx in indices[0][:5]:
print(classes[idx], prob[idx].item())
# out:
# 967: 'espresso', 87.85208892822266
# 968: 'cup', 7.28359317779541
# 504: 'coffee mug', 4.33521032333374
# 925: 'consomme', 0.36686763167381287
# 960: 'chocolate sauce, chocolate syrup',
# 0.09037172049283981

我们看到模型预测图像是87.85%的浓缩咖啡,7.28%的杯子和4.3%的咖啡杯的概率,但它似乎相当确信图像是浓缩咖啡。


3.完整代码

你可能觉得你现在就需要一杯浓缩咖啡。在那个例子中我们已经讨论了很多!完成所有事情的核心代码实际上要短得多。假设你已经下载了这些文件,你只需要运行以下代码来使用AlexNet对图像进行分类:

import torch
from torchvision import transforms, models
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
img_tensor = transform(img)
batch = img_tensor.unsqueeze(0)
model = models.alexnet(pretrained=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
y = model(batch.to(device))
prob = torch.nn.functional.softmax(y, dim=1)[0] * 100
_, indices = torch.sort(y, descending=True)
for idx in indices[0][:5]:
print(classes[idx], prob[idx].item())

这就是如何用PyTorch构建一个图像分类器。试着在模型中运行您自己的图像,看看它是如何对它们进行分类的。



推荐阅读
  • Android Studio Bumblebee | 2021.1.1(大黄蜂版本使用介绍)
    本文介绍了Android Studio Bumblebee | 2021.1.1(大黄蜂版本)的使用方法和相关知识,包括Gradle的介绍、设备管理器的配置、无线调试、新版本问题等内容。同时还提供了更新版本的下载地址和启动页面截图。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 解决Cydia数据库错误:could not open file /var/lib/dpkg/status 的方法
    本文介绍了解决iOS系统中Cydia数据库错误的方法。通过使用苹果电脑上的Impactor工具和NewTerm软件,以及ifunbox工具和终端命令,可以解决该问题。具体步骤包括下载所需工具、连接手机到电脑、安装NewTerm、下载ifunbox并注册Dropbox账号、下载并解压lib.zip文件、将lib文件夹拖入Books文件夹中,并将lib文件夹拷贝到/var/目录下。以上方法适用于已经越狱且出现Cydia数据库错误的iPhone手机。 ... [详细]
  • 拥抱Android Design Support Library新变化(导航视图、悬浮ActionBar)
    转载请注明明桑AndroidAndroid5.0Loollipop作为Android最重要的版本之一,为我们带来了全新的界面风格和设计语言。看起来很受欢迎࿰ ... [详细]
  • 域名解析系统DNS
    文章目录前言一、域名系统概述二、因特网的域名结构三、域名服务器1.根域名服务器2.顶级域名服务器(TLD,top-leveldomain)3.权威(Authoritative)域名 ... [详细]
  • 2017亚马逊人工智能奖公布:他们的AI有什么不同?
    事实上,在我们周围,“人工智能”让一切都变得更“智能”极具讽刺意味。随着人类与机器智能之间的界限变得模糊,我们的世界正在变成一个机器 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • 本文介绍了如何找到并终止在8080端口上运行的进程的方法,通过使用终端命令lsof -i :8080可以获取在该端口上运行的所有进程的输出,并使用kill命令终止指定进程的运行。 ... [详细]
  • 本文介绍了腾讯最近开源的BERT推理模型TurboTransformers,该模型在推理速度上比PyTorch快1~4倍。TurboTransformers采用了分层设计的思想,通过简化问题和加速开发,实现了快速推理能力。同时,文章还探讨了PyTorch在中间层延迟和深度神经网络中存在的问题,并提出了合并计算的解决方案。 ... [详细]
  • 如何使用代理服务器进行网页抓取?
    本文介绍了如何使用代理服务器进行网页抓取,并探讨了数据驱动对竞争优势的重要性。通过网页抓取,企业可以快速获取并分析大量与需求相关的数据,从而制定营销战略。同时,网页抓取还可以帮助电子商务公司在竞争对手的网站上下载数百页的有用数据,提高销售增长和毛利率。 ... [详细]
  • 本文介绍了RxJava在Android开发中的广泛应用以及其在事件总线(Event Bus)实现中的使用方法。RxJava是一种基于观察者模式的异步java库,可以提高开发效率、降低维护成本。通过RxJava,开发者可以实现事件的异步处理和链式操作。对于已经具备RxJava基础的开发者来说,本文将详细介绍如何利用RxJava实现事件总线,并提供了使用建议。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 前言:拿到一个案例,去分析:它该是做分类还是做回归,哪部分该做分类,哪部分该做回归,哪部分该做优化,它们的目标值分别是什么。再挑影响因素,哪些和分类有关的影响因素,哪些和回归有关的 ... [详细]
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
author-avatar
待续爱情2502861755
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有