热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch 模型 onnx 文件导出及调用详情

这篇文章主要介绍了PyTorch模型onnx文件导出及调用详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下

前言

Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移

PyTorch 所定义的模型为动态图,其前向传播是由类方法定义和实现的

但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升

PyTorch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,

最重要的有三个参数:

  • model:父类为 nn.Module 的模型
  • args:传入 model 的 forward 方法的变量列表,类型应为
  • tuplef:onnx 文件名称的字符串
import torch
from torchvision.models import resnet50
 
file = 'resnet.onnx'
# 声明模型
resnet = resnet50(pretrained=False).eval()
image = torch.rand([1, 3, 224, 224])
# 导出为 onnx 文件
torch.onnx.export(resnet, (image,), file)

onnx 文件可被 Netron 打开,以查看模型结构

基本用法

要在 Python 中运行 onnx 模型,需要下载 onnxruntime

# 选其一即可
pip install onnxruntime        # CPU 版本
pip install onnxruntime-gpu    # GPU 版本

推理时需要借助其中的 InferenceSession,其中较为重要的实例方法有:

  • get_inputs():得到输入变量的列表 (变量属性:name、shape、type)
  • get_outputs():得到输入变量的列表 (变量属性:name、shape、type)run(output_names, input_feed):输入变量为 numpy.ndarray (注意 dtype 应为 float32),使用模型推理并返回输出

可得出 onnx 模型的基本用法:

import onnxruntime as ort
import numpy as np
file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
    1 if ort.get_device() == 'GPU' else 0]
print('设备:', provider)
# 声明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])
# 参考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
    for node in node_list:
        attr = {'name': node.name,
                'shape': node.shape,
                'type': node.type}
        print(attr)
    print('-' * 60)
 
# 得到输入、输出结点的名称
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
                input_feed={input_node_name: image}))

高级 API

为了简化使用步骤,使用类进行封装:

class Onnx_Module(ort.InferenceSession):
    ''' onnx 推理模型
        provider: 优先使用 GPU'''
    provider = ort.get_available_providers()[
        1 if ort.get_device() == 'GPU' else 0]
 
    def __init__(self, file):
        super(Onnx_Module, self).__init__(file, providers=[self.provider])
        # 参考: ort.NodeArg
        self.inputs = [node_arg.name for node_arg in self.get_inputs()]
        self.outputs = [node_arg.name for node_arg in self.get_outputs()]
 
    def __call__(self, *arrays):
        input_feed = {name: x for name, x in zip(self.inputs, arrays)}
        return self.run(self.outputs, input_feed)

在 PyTorch 中,对于卷积神经网络 model 与图像 image,推理的代码为 "model(image)",而使用这个封装的类也是类似:

import numpy as np
file = 'resnet.onnx'
model = Onnx_Module(file)
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model(image))

为了方便观察 Torch 模型与 onnx 模型的速度差异,同时检查两个模型的输出是否一致,又编写了 test 函数

test 方法的参数与 torch.onnx.export 一致,其基本流程为:

  • 得到 Torch 模型的输出,并 print 推断耗时
  • 将 Torch 模型导出为 onnx 文件,将输入变量中的 torch.tensor 转化为 numpy.ndarray
  • 初始化 onnx 模型,得到 onnx 模型的输出,并 print 推断耗时
  • 计算 Torch 模型与 onnx 模型输出的绝对误差的均值
  • 将 onnx 模型 return
class Timer:
    repeat = 3
 
    def __new__(cls, fun, *args, **kwargs):
        import time
        start = time.time()
        for _ in range(cls.repeat): fun(*args, **kwargs)
        cost = (time.time() - start) / cls.repeat
        return cost * 1e3  # ms
 
 
class Onnx_Module(ort.InferenceSession):
    ''' onnx 推理模型
        provider: 优先使用 GPU'''
    provider = ort.get_available_providers()[
        1 if ort.get_device() == 'GPU' else 0]
 
    def __init__(self, file):
        super(Onnx_Module, self).__init__(file, providers=[self.provider])
        # 参考: ort.NodeArg
        self.inputs = [node_arg.name for node_arg in self.get_inputs()]
        self.outputs = [node_arg.name for node_arg in self.get_outputs()]
    def __call__(self, *arrays):
        input_feed = {name: x for name, x in zip(self.inputs, arrays)}
        return self.run(self.outputs, input_feed)
 
    @classmethod
    def test(cls, model, args, file, **export_kwargs):
        # 测试 Torch 的运行时间
        torch_output = model(*args).data.numpy()
        print(f'Torch: {Timer(model, *args):.2f} ms')
        # model: Torch -> onnx
        torch.onnx.export(model, args, file, **export_kwargs)
        # data: tensor -> array
        args = tuple(map(lambda tensor: tensor.data.numpy(), args))
        onnx_model = cls(file)
        # 测试 onnx 的运行时间
        onnx_output = onnx_model(*args)
        print(f'Onnx: {Timer(onnx_model, *args):.2f} ms')
        # 计算 Torch 模型与 onnx 模型输出的绝对误差
        abs_error = np.abs(torch_output - onnx_output).mean()
        print(f'Mean Error: {abs_error:.2f}')
        return onnx_model

对于 ResNet50 而言,Torch 模型的推断耗时为 172.67 ms,onnx 模型的推断耗时为 36.56 ms,onnx 模型的推断耗时仅为 Torch 模型的 21.17%

到此这篇关于PyTorch 模型 onnx 文件导出及调用详情的文章就介绍到这了,更多相关PyTorch文件导出内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!


推荐阅读
  • 在Windows系统中安装TensorFlow GPU版的详细指南与常见问题解决
    在Windows系统中安装TensorFlow GPU版是许多深度学习初学者面临的挑战。本文详细介绍了安装过程中的每一个步骤,并针对常见的问题提供了有效的解决方案。通过本文的指导,读者可以顺利地完成安装并避免常见的陷阱。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式
    大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式 ... [详细]
  • Python错误重试让多少开发者头疼?高效解决方案出炉
    ### 优化后的摘要在处理 Python 开发中的错误重试问题时,许多开发者常常感到困扰。为了应对这一挑战,`tenacity` 库提供了一种高效的解决方案。首先,通过 `pip install tenacity` 安装该库。使用时,可以通过简单的规则配置重试策略。例如,可以设置多个重试条件,使用 `|`(或)和 `&`(与)操作符组合不同的参数,从而实现灵活的错误重试机制。此外,`tenacity` 还支持自定义等待时间、重试次数和异常处理,为开发者提供了强大的工具来提高代码的健壮性和可靠性。 ... [详细]
  • 在机器学习领域,深入探讨了概率论与数理统计的基础知识,特别是这些理论在数据挖掘中的应用。文章重点分析了偏差(Bias)与方差(Variance)之间的平衡问题,强调了方差反映了不同训练模型之间的差异,例如在K折交叉验证中,不同模型之间的性能差异显著。此外,还讨论了如何通过优化模型选择和参数调整来有效控制这一平衡,以提高模型的泛化能力。 ... [详细]
  • 在Conda环境中高效配置并安装PyTorch和TensorFlow GPU版的方法如下:首先,创建一个新的Conda环境以避免与基础环境发生冲突,例如使用 `conda create -n pytorch_gpu python=3.7` 命令。接着,激活该环境,确保所有依赖项都正确安装。此外,建议在安装过程中指定CUDA版本,以确保与GPU兼容性。通过这些步骤,可以确保PyTorch和TensorFlow GPU版的顺利安装和运行。 ... [详细]
  • 能够感知你情绪状态的智能机器人即将问世 | 科技前沿观察
    本周科技前沿报道了多项重要进展,包括美国多所高校在机器人技术和自动驾驶领域的最新研究成果,以及硅谷大型企业在智能硬件和深度学习技术上的突破性进展。特别值得一提的是,一款能够感知用户情绪状态的智能机器人即将问世,为未来的人机交互带来了全新的可能性。 ... [详细]
  • 在Windows环境下离线安装PyTorch GPU版时,首先需确认系统配置,例如本文作者使用的是Win8、CUDA 8.0和Python 3.6.5。用户应根据自身Python和CUDA版本,在PyTorch官网查找并下载相应的.whl文件。此外,建议检查系统环境变量设置,确保CUDA路径正确配置,以避免安装过程中可能出现的兼容性问题。 ... [详细]
  • 从2019年AI顶级会议最佳论文,探索深度学习的理论根基与前沿进展 ... [详细]
  • 视觉图像的生成机制与英文术语解析
    近期,Google Brain、牛津大学和清华大学等多家研究机构相继发布了关于多层感知机(MLP)在视觉图像分类中的应用成果。这些研究深入探讨了MLP在视觉任务中的工作机制,并解析了相关技术术语,为理解视觉图像生成提供了新的视角和方法。 ... [详细]
  • MicrosoftDeploymentToolkit2010部署培训实验手册V1.0目录实验环境说明3实验环境虚拟机使用信息3注意:4实验手册正文说 ... [详细]
  • 本文介绍了如何利用 `matplotlib` 库中的 `FuncAnimation` 类将 Python 中的动态图像保存为视频文件。通过详细解释 `FuncAnimation` 类的参数和方法,文章提供了多种实用技巧,帮助用户高效地生成高质量的动态图像视频。此外,还探讨了不同视频编码器的选择及其对输出文件质量的影响,为读者提供了全面的技术指导。 ... [详细]
  • 在《Python编程基础》课程中,我们将深入探讨Python中的循环结构。通过详细解析for循环和while循环的语法与应用场景,帮助初学者掌握循环控制语句的核心概念和实际应用技巧。此外,还将介绍如何利用循环结构解决复杂问题,提高编程效率和代码可读性。 ... [详细]
  • Python与R语言在功能和应用场景上各有优势。尽管R语言在统计分析和数据可视化方面具有更强的专业性,但Python作为一种通用编程语言,适用于更广泛的领域,包括Web开发、自动化脚本和机器学习等。对于初学者而言,Python的学习曲线更为平缓,上手更加容易。此外,Python拥有庞大的社区支持和丰富的第三方库,使其在实际应用中更具灵活性和扩展性。 ... [详细]
author-avatar
cc_lzx_530
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有