热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch的C++extension写法

加入极市专业CV交流群,与同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注

加入极市专业CV交流群,与  1 0000+来自港科大、北大、清华、中科院、CMU、腾讯、百度  等名校名企视觉开发者互动交流!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注  极市平台  公众号  , 回复  加群, 立刻申请入群~

本文授权转自知乎作者 Monstarrrr,https://zhuanlan.zhihu.com/p/100459760。未经作者许可,不可二次转载。

2019年的最后一天,终于填了一个早就想了解的坑。就是关于pytorch如何自定义一个扩展,这里主要是说C++扩展。

首先为什么需要扩展?python调用C++的库也是可行的啊。刚开始我也在思考这个问题,觉得没有必要。但是后来深入了解了以后发现还是有必要的。举个栗子,调用始终是使用的是别人的东西,但是扩展则是通过他人的帮助来完成一个属于自己的东西。

pytorch的C++ extension和 python 的c/c++ extension其实原理差不多,本质上都是为了扩展各自的功能,当然也为了使程序运行更加有效率,差别在于pytorch的C++ extension实施步骤较python的c/c++ extension的要简化一些。

这里以实现神经网络自定义的layer为例:

先说一下基本的流程:

  • 利用C++写好自定义层发功能,主要包括前向传播和方向传播,以及pybind11的内容。

  • 写好setup.py脚本, 并利用python提供的setuptools来编译并加载C++代码。

  • 编译安装,在python中调用C++扩展接口

pybind11是python的一个库,主要负责python与C++11之间的通信

下面就以一个最简单的z=2x+y来看看如何一步步完成这样一个简单运算的layer。

第一步:编写头文件,这里就叫做test.h

/*test.h*/
#include #include 
// forward propagation
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);
// backward propagation
std::vector Test_backward_cpu(const torch::Tensor& gradOutput);

这里包含一个重要的头文件

这个头文件里面包含很多重要的模块。如用于python和C++11交互的pybind11,以及包含Tensor的一系列定义操作,因为pytorch的基本数据单元是Tensor。

头文件写完以后就要开始写源文件了test.cpp

/*test.cpp*/
#include "test.h"
// part1:forward propagation
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y)
{
    AT_ASSERTM(x.sizes() == y.sizes());
    torch::Tensor z = torch::zeros(x.sizes());
    z = 2 * x + y;
    return z;
}

//part2:backward propagation
std::vector Test_backward_cpu(const torch::Tensor& gradOutput)
{
    torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());
    torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());
    return {gradOutputX, gradOutputY};
}

// part3:pybind11 (将python与C++11进行绑定, 注意这里的forward,backward名称就是后来在python中可以引用的方法名)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("forward", &Test_forward_cpu, "Test forward");
    m.def("backward", &Test_backward_cpu, "Test backward");
}

源文件cpp里面包含了三个部分,第一个部分是forward函数,第二个部分是backward函数,第三个部分是pytorch和C++交互的部分。

至此C++部分的工作就完成了。也就是我们的步骤一:利用C++写好自定义层发功能,主要包括前向传播和方向传播,以及pybind11的内容。

下面的工作就是pytorch如何识别和使用这个扩展程序了。

第二步:编写setup.py,这个文件的主要作用是用来编译C++文件以及建立链接关系。

现在的文件目录排布为:

pytorch的C++ extension写法

setup.py中的内容为:

from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension

# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))
#源代码目录
source_file = glob.glob(os.path.join(working_dirs, 'src', '*.cpp'))

setup(
    name='test_cpp',  # 模块名称
    ext_modules=[CppExtension('test_cpp', sources=source_file, include_dirs=[include_dirs])],
    cmdclass={
        'build_ext': BuildExtension
    }
)

这一部分基本上算是一个固定的格式针对不同的问题需要修改的地方就是ext_modules参数,这里面根据实际的需要列表中可以存在多个CppExtension模块,也就是说可以同时编译多个C++文件。

例如像这样:

pytorch的C++ extension写法

完成setup.py以后,需要在终端执行python setup.py install

NOTE:建议将扩展安装在个人虚拟环境中

这一步其实是包含了build+install执行的是先编译链接动态链接库,然后将构建好的文件以package的形式安装存放再 当前开发环境 的package的集中存放处,这样就相当于生成了一个完整的package了。和其他的如numpy,torch这些package没什么两样。

执行完这一步后就生成了这一堆东西:

pytorch的C++ extension写法

这样,我们的第二步“写好setup.py脚本, 并利用python提供的setuptools来编译并加载C++代码。”也完成了。

NOTE:此时如果在python的控制台中输入import test_cpp会得到这样的错误:

undefined symbol: _ZTIN3c1021AutogradMetaInterfaceE

原因是因为它还没有封装起来,暂时还见不得人~。

下面是最后一步:封装调用这个扩展(extension),先在与setup.py相同的目录下新建一个test.py

内容为:

from torch.autograd import Function
import torch
import test_cpp


class _TestFunction(Function):
    @staticmethod
    def forward(ctx, x, y):
        """
        It must accept a context ctx as the first argument, followed by any
        number of arguments (tensors or other types).
        The context can be used to store tensors that can be then retrieved
        during the backward pass."""
        return test_cpp.forward(x, y)

    @staticmethod
    def backward(ctx, gradOutput):
        gradX, gradY = test_cpp.backward(gradOutput)
        return gradX, gradY

# 封装成一个模块(Module)
class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()

    def forward(self, inputA, inputB):
        return _TestFunction.apply(inputA, inputB)

这是pytorch的autograd中的一个扩展函数的接口模板。基本pytorch中所有的层的前向传播和反向传播都是这样写的。

关于pytorch的方向传播的细节,有两个需要注意的点。其一,forward函数中有一个ctx变量,这是一定需要的。因为这里面会存一些对方向传播有用的变量(因为有些函数求导是需要用到前向计算过程中的一些计算结果)。backward中也有ctx参数,可以获取从forward函数中所保存的变量。第二个需要注意的点就是backward输出的都是关于变量的梯度,其数目要和forwad中输入的一致,这是一种强制性的要求,如果有些变量不需要求导,就直接返回None即可。

一切就绪以后,可以开始使用了,但是在这之前还需要确定你写的反向传播层的梯度否计算正确。pytorch提供了一个torch.autograd.gradcheck()函数来检查的所计算的梯度是否合理。这个检查的原理是通过比较梯度的数值计算和解析表达之间的误差来判断梯度计算是否正确:

梯度的数值计算法:

pytorch的C++ extension写法

梯度的解析法就是我们通过求导公式计算得到的。如:

pytorch的C++ extension写法

这一步检查无误以后就可以happy的使用这个模块了。也就是说完整的完成了一个pytorch的c++扩展。

总结一下:首先要写C++源码程序,需要使用一个torch的库。这个库里面规定了如何利用这个库来写C++的extension,里面的基本数据格式为Tensor类型,不是一般的int/char/float类型。需要在C++源码中写一个forward函数和backward函数,在C++源码的最后使用PYBIND11来进行C++11和python的对话。之后便是编写setup.py文件,使用python提供的setuptools和torch自带的BuildExtension和CppExtension工具来进行编译的准备工作。然后再命令行中键入python setup.py install(根据需要使用build或者install,如你不想把你的package安装到系统路径中去,也就是site-package中,那么就用build命令,反之就用install命令),编译完成后还需要使用torch.autograd.Function来将这个扩展写成一个函数,方便在构建网络的时候调用。最后就在合适的地方使用Function.apply(*args)。这样一个定制化的模块就搞定了,也就是说完成了一个完整的pytorch的C++扩展(让用户丝毫感觉不到这个代码是在C++上扩展的~,但是作为一名算法菜狗,了解这些过程还是有必要的,这本身就是问题的一部分,包括如何写新网络中新的layer,毕竟看到再多,如果不自己动手始终差些感觉)

最后补充一点,关于求导,标量对标量的求导就不用多说了,这里主要是标量对向量/矩阵的求导。

pytorch的C++ extension写法
page1
pytorch的C++ extension写法
page2

这里使用到的标量对于矩阵的求导方法可以参见以下文章,写的非常好。

https://zhuanlan.zhihu.com/p/24709748

-END -

推荐阅读:

  • 你有哪些deep learning(rnn、cnn)调参的经验?

  • 实践经验分享:在深度学习中喂饱GPU

  • 神经网络训练tricks

极市独家福利

40万奖金的AI移动应用大赛,参赛就有奖,入围还有额外奖励

pytorch的C++ extension写法

添加极市小助手微信 (ID : cv-mart) ,备注: 研究方向-姓名-学校/公司-城市 (如:目标检测-小极-北大-深圳),即可申请加入 目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群 ,更有 每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、 干货资讯汇总、行业技术交流 一起来让思想之光照的更远吧~

pytorch的C++ extension写法

△长按添加极市小助手

pytorch的C++ extension写法

△长按关注极市平台,获取 最新CV干货

觉得有用麻烦给个在看啦~   


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 我们


推荐阅读
  • 本文介绍了Python对Excel文件的读取方法,包括模块的安装和使用。通过安装xlrd、xlwt、xlutils、pyExcelerator等模块,可以实现对Excel文件的读取和处理。具体的读取方法包括打开excel文件、抓取所有sheet的名称、定位到指定的表单等。本文提供了两种定位表单的方式,并给出了相应的代码示例。 ... [详细]
  • 安装mysqlclient失败解决办法
    本文介绍了在MAC系统中,使用django使用mysql数据库报错的解决办法。通过源码安装mysqlclient或将mysql_config添加到系统环境变量中,可以解决安装mysqlclient失败的问题。同时,还介绍了查看mysql安装路径和使配置文件生效的方法。 ... [详细]
  • 搭建Windows Server 2012 R2 IIS8.5+PHP(FastCGI)+MySQL环境的详细步骤
    本文详细介绍了搭建Windows Server 2012 R2 IIS8.5+PHP(FastCGI)+MySQL环境的步骤,包括环境说明、相关软件下载的地址以及所需的插件下载地址。 ... [详细]
  • Python实现变声器功能(萝莉音御姐音)的方法及步骤
    本文介绍了使用Python实现变声器功能(萝莉音御姐音)的方法及步骤。首先登录百度AL开发平台,选择语音合成,创建应用并填写应用信息,获取Appid、API Key和Secret Key。然后安装pythonsdk,可以通过pip install baidu-aip或python setup.py install进行安装。最后,书写代码实现变声器功能,使用AipSpeech库进行语音合成,可以设置音量等参数。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 计算机存储系统的层次结构及其优势
    本文介绍了计算机存储系统的层次结构,包括高速缓存、主存储器和辅助存储器三个层次。通过分层存储数据可以提高程序的执行效率。计算机存储系统的层次结构将各种不同存储容量、存取速度和价格的存储器有机组合成整体,形成可寻址存储空间比主存储器空间大得多的存储整体。由于辅助存储器容量大、价格低,使得整体存储系统的平均价格降低。同时,高速缓存的存取速度可以和CPU的工作速度相匹配,进一步提高程序执行效率。 ... [详细]
  • 本文介绍了计算机网络的定义和通信流程,包括客户端编译文件、二进制转换、三层路由设备等。同时,还介绍了计算机网络中常用的关键词,如MAC地址和IP地址。 ... [详细]
  • 本文介绍了在Windows环境下如何配置php+apache环境,包括下载php7和apache2.4、安装vc2015运行时环境、启动php7和apache2.4等步骤。希望对需要搭建php7环境的读者有一定的参考价值。摘要长度为169字。 ... [详细]
  • Android源码深入理解JNI技术的概述和应用
    本文介绍了Android源码中的JNI技术,包括概述和应用。JNI是Java Native Interface的缩写,是一种技术,可以实现Java程序调用Native语言写的函数,以及Native程序调用Java层的函数。在Android平台上,JNI充当了连接Java世界和Native世界的桥梁。本文通过分析Android源码中的相关文件和位置,深入探讨了JNI技术在Android开发中的重要性和应用场景。 ... [详细]
  • C++字符字符串处理及字符集编码方案
    本文介绍了C++中字符字符串处理的问题,并详细解释了字符集编码方案,包括UNICODE、Windows apps采用的UTF-16编码、ASCII、SBCS和DBCS编码方案。同时说明了ANSI C标准和Windows中的字符/字符串数据类型实现。文章还提到了在编译时需要定义UNICODE宏以支持unicode编码,否则将使用windows code page编译。最后,给出了相关的头文件和数据类型定义。 ... [详细]
  • Java在运行已编译完成的类时,是通过java虚拟机来装载和执行的,java虚拟机通过操作系统命令JAVA_HOMEbinjava–option来启 ... [详细]
  • 本文介绍了2020年计算机二级MSOffice的选择习题及答案,详细解析了操作系统的五大功能模块,包括处理器管理、作业管理、存储器管理、设备管理和文件管理。同时,还解答了算法的有穷性的含义。 ... [详细]
  • 本文介绍了深入浅出Linux设备驱动编程的重要性,以及两种加载和删除Linux内核模块的方法。通过一个内核模块的例子,展示了模块的编译和加载过程,并讨论了模块对内核大小的控制。深入理解Linux设备驱动编程对于开发者来说非常重要。 ... [详细]
  • 本文详细介绍了GetModuleFileName函数的用法,该函数可以用于获取当前模块所在的路径,方便进行文件操作和读取配置信息。文章通过示例代码和详细的解释,帮助读者理解和使用该函数。同时,还提供了相关的API函数声明和说明。 ... [详细]
  • 本文介绍了蓝桥训练中的闰年判断问题,并提供了使用Python代码进行判断的方法。根据给定的年份,判断是否为闰年的条件是:年份是4的倍数且不是100的倍数,或者是400的倍数。根据输入的年份,输出结果为yes或no。本文提供了相应的Python代码实现。 ... [详细]
author-avatar
隔壁老吴
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有