热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

开发笔记:他山之石在C++平台上部署PyTorch模型流程+踩坑实录

篇首语:本文由编程笔记#小编为大家整理,主要介绍了他山之石在C++平台上部署PyTorch模型流程+踩坑实录相关的知识,希望对你有一定的参考价值。

篇首语:本文由编程笔记#小编为大家整理,主要介绍了他山之石在C++平台上部署PyTorch模型流程+踩坑实录相关的知识,希望对你有一定的参考价值。















最近因为工作需要,要把pytorch的模型部署到c++平台上,基本过程主要参照官网的教学示例,期间发现了不少坑,特此记录。








作者:火星少女




01




























模型转换



libtorch不依赖于python,python训练的模型,需要转换为script model才能由libtorch加载,并进行推理。在这一步官网提供了两种方法:


方法一:Tracing


这种方法操作比较简单,只需要给模型一组输入,走一遍推理网络,然后由torch.ji.trace记录一下路径上的信息并保存即可。示例如下:

























import torchimport torchvision
# An instance of your model.model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.traced_script_module = torch.jit.trace(model, example)





缺点是如果模型中存在控制流比如if-else语句,一组输入只能遍历一个分支,这种情况下就没办法完整的把模型信息记录下来。


方法二:Scripting


直接在Torch脚本中编写模型并相应地注释模型,通过torch.jit.script编译模块,将其转换为ScriptModule。示例如下:




























class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output
my_module = MyModule(10,20)sm = torch.jit.script(my_module)







  • forward方法会被默认编译,forward中被调用的方法也会按照被调用的顺序被编译



  • 如果想要编译一个forward以外且未被forward调用的方法,可以添加 @torch.jit.export.



  • 如果想要方法不被编译,可使用@torch.jit.ignore[1] 或者 @torch.jit.unused[2]





































# Same behavior as pre-PyTorch 1.2@torch.jit.scriptdef some_fn(): return 2
# Marks a function as ignored, if nothing# ever calls it then this has no effect@torch.jit.ignoredef some_fn2(): return 2
# As with ignore, if nothing calls it then it has no effect.# If it is called in script it is replaced with an exception.@torch.jit.unuseddef some_fn3(): import pdb; pdb.set_trace() return 4
# Doesn't do anything, this function is already# the main entry point@torch.jit.exportdef some_fn4(): return 2





在这一步遇到好多坑,主要原因可归为一下两点


1. 不支持的操作


TorchScript支持的操作是python的子集,大部分torch中用到的操作都可以找到对应实现,但也存在一些尴尬的不支持操作,详细列表可见unsupported-ops[3],下面列一些我自己遇到的操作:


1)参数/返回值不支持可变个数,例如


















def __init__(self, **kwargs):

或者



















if output_flag == 0: return reshape_logitselse: loss = self.loss(reshape_logits, term_mask, labels_id) return reshape_logits, loss





2)各种iteration操作


eg1.


















layers = [int(a) for a in layers]

报错torch.jit.frontend.UnsupportedNodeError: ListComp aren’t supported



可以改成:
















for k in range(len(layers)): layers[k] = int(layers[k])





eg2.



















seq_iter = enumerate(scores)try: _, inivalues = seq_iter.__next__()except: _, inivalues = seq_iter.next()





eg3.


















line = next(infile)

3)不支持的语句



eg1. 不支持continue


torch.jit.frontend.UnsupportedNodeError: continue statements aren’t supported


eg2. 不支持try-catch


torch.jit.frontend.UnsupportedNodeError: try blocks aren’t supported


eg3. 不支持with语句


4)其他常见op/module


eg1. torch.autograd.Variable


解决:使用torch.ones/torch.randn等初始化+.float()/.long()等指定数据类型。


eg2. torch.Tensor/torch.LongTensor etc.


解决:同上


eg3. requires_grad参数只在torch.tensor中支持,torch.ones/torch.zeros等不可用


eg4. tensor.numpy()


eg5. tensor.bool()


解决:tensor.bool()用tensor>0代替


eg6. self.seg_emb(seg_fea_ids).to(embeds.device)


解决:需要转gpu的地方显示调用.cuda()


总之一句话:除了原生python和pytorch以外的库,比如numpy什么的能不用就不用,尽量用pytorch的各种API。


2.指定数据类型


1)属性,大部分的成员数据类型可以根据值来推断,空的列表/字典则需要预先指定































from typing import Dict
class MyModule(torch.nn.Module): my_dict: Dict[str, int]
def __init__(self): super(MyModule, self).__init__() # This type cannot be inferred and must be specified self.my_dict = {}
# The attribute type here is inferred to be `int` self.my_int = 20
def forward(self): pass
m = torch.jit.script(MyModule())





2)常量,使用Final关键字

































try: from typing_extensions import Finalexcept: # If you don't have `typing_extensions` installed, you can use a # polyfill from `torch.jit`. from torch.jit import Final
class MyModule(torch.nn.Module):
my_constant: Final[int]
def __init__(self): super(MyModule, self).__init__() self.my_cOnstant= 2
def forward(self): pass
m = torch.jit.script(MyModule())





3)变量。默认是tensor类型且不可变,所以非tensor类型必须要指明


















def forward(self, batch_size:int, seq_len:int, use_cuda:bool):

方法三:Tracing and Scriptin混合



一种是在trace模型中调用script,适合模型中只有一小部分需要用到控制流的情况,使用实例如下:





























import torch
@torch.jit.scriptdef foo(x, y): if x.max() > y.max(): r = x else: r = y return r

def bar(x, y, z): return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))





另一种情况是在script module中用tracing生成子模块,对于一些存在script module不支持的python feature的layer,就可以把相关layer封装起来,用trace记录相关layer流,其他layer不用修改。使用示例如下:





























import torchimport torchvision
class MyScriptModule(torch.nn.Module): def __init__(self): super(MyScriptModule, self).__init__() self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) .resize_(1, 3, 1, 1)) self.resnet = torch.jit.trace(torchvision.models.resnet18(), torch.rand(1, 3, 224, 224))
def forward(self, input): return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())





02




























保存序列化模型



如果上一步的坑都踩完,那么模型保存就非常简单了,只需要调用save并传递一个文件名即可,需要注意的是如果想要在gpu上训练模型,在cpu上做inference,一定要在模型save之前转化,再就是记得调用model.eval(),形如






















gpu_model.eval()cpu_model = gpu_model.cpu()sample_input_cpu = sample_input_gpu.cpu()traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu)torch.jit.save(traced_cpu, "cpu.pth")
traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu)torch.jit.save(traced_gpu, "gpu.pth")





03




























C++ load训练好的模型



要在C ++中加载序列化的PyTorch模型,必须依赖于PyTorch C ++ API(也称为LibTorch)。libtorch的安装非常简单,只需要在pytorch官网下载对应版本,解压即可。会得到一个结构如下的文件夹。



















libtorch/ bin/ include/ lib/ share/





然后就可以构建应用程序了,一个简单的示例目录结构如下:

















example-app/ CMakeLists.txt example-app.cpp





example-app.cpp和CMakeLists.txt的示例代码分别如下:





































#include // One-stop header.#include <iostream>#include int main(int argc, const char* argv[]) { if (argc != 2) { std::cerr <<"usage: example-app \n"; return -1;  }
torch::jit::script::Module module; try { // Deserialize the ScriptModule from a file using torch::jit::load(). module = torch::jit::load(argv[1]); } catch (const c10::Error& e) { std::cerr <<"error loading the model\n"; return -1; }
std::cout <<"ok\n";}






















cmake_minimum_required(VERSION 3.0 FATAL_ERROR)project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)target_link_libraries(example-app "${TORCH_LIBRARIES}")set_property(TARGET example-app PROPERTY CXX_STANDARD 14)





至此,就可以运行以下命令从example-app/文件夹中构建应用程序啦:


















mkdir buildcd buildcmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..cmake --build . --config Release





其中/path/to/libtorch是之前下载后的libtorch文件夹所在的路径。这一步如果顺利能够看到编译完成100%的提示,下一步运行编译生成的可执行文件,会看到“ok”的输出,可喜可贺!


04




























执行Script Module



终于到最后一步啦!下面只需要按照构建输入传给模型,执行forward就可以得到输出啦。一个简单的示例如下:





















// Create a vector of inputs.std::vector inputs;inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.at::Tensor output = module.forward(inputs).toTensor();std::cout </*dim=*/1, /*start=*/0, /*end=*/5) <<'\n';





前两行创建一个torch::jit::IValue的向量,并添加单个输入. 使用torch::ones()创建输入张量,等效于C ++ API中的torch.ones。然后,运行script::Module的forward方法,通过调用toTensor()将返回的IValue值转换为张量。C++对torch的各种操作还是比较友好的,通过torch::或者后加_的方法都可以找到对应实现,例如
















torch::tensor(input_list[j]).to(at::kLong).resize_({batch, 128}).clone()//torch::tensor对应pytorch的torch.tensor; at::kLong对应torch.int64;resize_对应resize





最后check一下确保c++端的输出和pytorch是一致的就大功告成啦~


踩了无数坑,薅掉了无数头发,很多东西也是自己一点点摸索的,如果有错误欢迎指正!






参考资料:

[1] https://pytorch.org/docs/master/generated/torch.jit.ignore.html#torch.jit.ignore

[2] https://pytorch.org/docs/master/generated/torch.jit.unused.html#torch.jit.unused

[3] https://pytorch.org/docs/master/jit_unsupported.html#jit-unsupported

https://pytorch.org/cppdocs/

https://pytorch.org/tutorials/advanced/cpp_export.html
























直播预告











【他山之石】在C++平台上部署PyTorch模型流程+踩坑实录










左划查看更多






【他山之石】在C++平台上部署PyTorch模型流程+踩坑实录






【他山之石】在C++平台上部署PyTorch模型流程+踩坑实录


























历史文章推荐











































分享、点赞、在看,给个三连击呗!








推荐阅读
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • Java实战之电影在线观看系统的实现
    本文介绍了Java实战之电影在线观看系统的实现过程。首先对项目进行了简述,然后展示了系统的效果图。接着介绍了系统的核心代码,包括后台用户管理控制器、电影管理控制器和前台电影控制器。最后对项目的环境配置和使用的技术进行了说明,包括JSP、Spring、SpringMVC、MyBatis、html、css、JavaScript、JQuery、Ajax、layui和maven等。 ... [详细]
  • Oracle seg,V$TEMPSEG_USAGE与Oracle排序的关系及使用方法
    本文介绍了Oracle seg,V$TEMPSEG_USAGE与Oracle排序之间的关系,V$TEMPSEG_USAGE是V_$SORT_USAGE的同义词,通过查询dba_objects和dba_synonyms视图可以了解到它们的详细信息。同时,还探讨了V$TEMPSEG_USAGE的使用方法。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 【shell】网络处理:判断IP是否在网段、两个ip是否同网段、IP地址范围、网段包含关系
    本文介绍了使用shell脚本判断IP是否在同一网段、判断IP地址是否在某个范围内、计算IP地址范围、判断网段之间的包含关系的方法和原理。通过对IP和掩码进行与计算,可以判断两个IP是否在同一网段。同时,还提供了一段用于验证IP地址的正则表达式和判断特殊IP地址的方法。 ... [详细]
  • MPLS VP恩 后门链路shamlink实验及配置步骤
    本文介绍了MPLS VP恩 后门链路shamlink的实验步骤及配置过程,包括拓扑、CE1、PE1、P1、P2、PE2和CE2的配置。详细讲解了shamlink实验的目的和操作步骤,帮助读者理解和实践该技术。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了C#中生成随机数的三种方法,并分析了其中存在的问题。首先介绍了使用Random类生成随机数的默认方法,但在高并发情况下可能会出现重复的情况。接着通过循环生成了一系列随机数,进一步突显了这个问题。文章指出,随机数生成在任何编程语言中都是必备的功能,但Random类生成的随机数并不可靠。最后,提出了需要寻找其他可靠的随机数生成方法的建议。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • 本文介绍了iOS数据库Sqlite的SQL语句分类和常见约束关键字。SQL语句分为DDL、DML和DQL三种类型,其中DDL语句用于定义、删除和修改数据表,关键字包括create、drop和alter。常见约束关键字包括if not exists、if exists、primary key、autoincrement、not null和default。此外,还介绍了常见的数据库数据类型,包括integer、text和real。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • 本文介绍了在wepy中运用小顺序页面受权的计划,包含了用户点击作废后的从新受权计划。 ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
author-avatar
mobiledu2502883183
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有