热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TorchScript是如何加速RNN的?

本文改写自OptimizingCUDARecurrentNeuralNetworkswithTorchScript,有删改。RNN在多种NLP任务上有良好表现,PyTorch自带了

本文改写自 Optimizing CUDA Recurrent Neural Networks with TorchScript,有删改。

RNN 在多种 NLP 任务上有良好表现,PyTorch 自带了几种 RNN 的典型实现(例如 Elman RNN,GRU,LSTM 以及它们的 multi-layered 和 bidirectional 版本)。

有时用户想修改 RNN 的一些实现细节。例如,用户想将 Layer Normalization 应用到 LSTM 中,这一改动很难实现,因为 PyTorch CUDA LSTM 是高度一体化的。此时用户会尝试使用 PyTorch 的基本 Operator 来定制他们想要的 LSTM,这会带来开销:RNN 大量的使用 Operator,而多数 Operator 会在 GPU 上启动至少一个 kernel。

可以使用 TorchScript 来改善这种情况:它可以优化代码并 Fuse Operation,以降低在 GPU 上启动的 kernel 个数,并使 kernel 质量更高。

本文的目标是让用户自然、快速的实现 RNN,并达到和手工优化的 CUDA kernel 同样的性能。

本文代码链接github.com

Operator Fuse 的局限性

如果不了解 Operator Fuse,可参考这篇文章:内核融合:GPU深度学习的“加速神器”

PyTorch JIT 能将相邻的 element-wise 操作 Fuse 到一个 FusionGroup 中,这个 FusionGroup 只会启动一个 GPU/CPU kernel。

这意味着,如果使用了较为复杂的 Operator(例如:混合了 element-wise 的 reduce 操作),JIT 识别可 Fuse 的 Operator 时会遇到困难。此时,可以尝试分离 reduce 操作 与 element-wise 操作,这样做的话,JIT 就可以将几个 element-wise 操作 Fuse 在一个 Fusion Group 中了。

在本文 LSTM Cell(forward)一节中,我们可以看到 PyTorch JIT 会尝试着在保证程序正确性的情况下,将 element-wise 操作尽可能放在一起,从而进行 Operator Fuse。

PyTorch JIT 对 LSTM 的优化

LSTM Cell(forward)

几乎 LSTM 中所有的计算都发生在 LSTMCell 中,下面是 LSTMCell 在 TorchScript 下的一种实现:

class LSTMCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
self.bias_ih = Parameter(torch.randn(4 * hidden_size))
self.bias_hh = Parameter(torch.randn(4 * hidden_size))
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = state
gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return hy, (hy, cy)

重排 Chunk 和 Pointwise Ops,挖掘更多的 Fuse 机会

上面实现的 LSTMCell

  1. 将 gate 加(Pointwise Ops)在一起
  2. 将加在一起的结果 chunk 为四小块
  3. 对每个小块执行激活函数(Pointwise Ops)

如果不重排 Chunk 和 Pointwise Ops,这个实现会产生两个 Fusion Group(chunk 之前和 chunk 之后)。

为了避免这个问题,JIT 调换了第一个 Pointwise Ops 和 chunk 的作用对象和顺序(如图):

调换前:对输入执行 Pointwise Ops,对输出执行 Chunk(图左侧路径)

调换后:对输入执行 Chunk,对输出执行 Pointwise Ops(图右侧路径)

《TorchScript 是如何加速 RNN 的?》
《TorchScript 是如何加速 RNN 的?》

通过这一重排,原来的两个 Fusion Group 就可以 Fuse 为一个了,我们用 graph_for 来看 JIT 将哪些操作 Fuse 到了一起:

# get inputs and states for LSTMCell
inputs = get_lstm_inputs()
# instantiate a ScriptModule
cell = LSTMCell(input_size, hidden_size)
# print the optimized graph using graph_for
out = cell(inputs)
print(cell.graph_for(inputs))

输出是优化过的 TorchScript Graph(又名 PyTorch JIT IR[1]

graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30)
with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *),
%29 : Float(*),
%33 : Float(*),
%40 : Float(*, *),
%43 : Float(*, *),
%45 : Float(*, *),
%48 : Float(*, *)):
%49 : Float(*, *) = aten::t(%48)
%47 : Float(*, *) = aten::mm(%45, %49)
%44 : Float(*, *) = aten::t(%43)
%42 : Float(*, *) = aten::mm(%40, %44)
...some broadcast sizes operations...
%hy : Float(*, *), %287 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) =prim::FusionGroup_0(%13, %346, %345, %344, %343)
...some broadcast sizes operations...
return (%hy, %cy, %49, %44, %196, %199, %340, %192, %325, %185, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %395, %396, %287)
with prim::FusionGroup_0 = graph(%13 : Float(*, *),
%71 : Tensor,
%76 : Tensor,
%81 : Tensor,
%86 : Tensor):
...some chunks, constants, and add operations...
%ingate.1 : Float(*, *) = aten::sigmoid(%38)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%34)
%cellgate.1 : Float(*, *) = aten::tanh(%30)
%outgate.1 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
%11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%cy : Float(*, *) = aten::add(%14, %11, %69)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %4)
return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)

从输出可以看到,子图 prim::FusionGroup_0 Fuse 了 LSTMCell 的所有 element-wise 操作(sigmoid,tanh,mul,add)(transpose 和 matrix multiplication 不是 element-wise 操作)。最后生成的 IR 中只有一个 FusionGroup,而不是两个。

LSTM Layer

下面是 LSTM Layer 在 TorchScript 下的一种实现:

class LSTMLayer(jit.ScriptModule):
def __init__(self, cell, *cell_args):
super(LSTMLayer, self).__init__()
self.cell = cell(*cell_args)
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
inputs = input.unbind(0)
outputs = torch.jit.annotate(List[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
return torch.stack(outputs), state

forward

循环展开(Loop Unrolling)

JIT 将代码中的循环展开了(对于大循环,仅展开一小部分),这使得 Fuser 可以在多个迭代步骤的循环体上做 Fuse,挖掘对 For Loops Control Flow 更多的优化机会。

Batch Matrix Multiplication

将多个矩阵乘法 Batch 在一起执行:

《TorchScript 是如何加速 RNN 的?》
《TorchScript 是如何加速 RNN 的?》

backward

Automatic Differentiation

不同于 PyTorch 原有的,基于 Tape 的 Autograd;TorchScript 使用了基于符号微分的自动微分机制,这使得运行时不必记录运算轨迹了,减轻了对存储和计算的负担。

变长序列最佳实践

TorchScript 不支持 PackedSequence。在处理变长序列时,最好将它们 pad 到同一长度放到单个 Tensor 中,再传给 TorchScript LSTM。例如:

sequences = [...] # List[Tensor], each Tensor is T' x C
padded = torch.utils.rnn.pad_sequence(sequences)
lengths = [seq.size(0) for seq in sequences]
padded # T x N x C, where N is batch size and T is the max of all T'
model = LSTM(...)
output, hiddens = model(padded)
output # T x N x C

output 可能会在 pad 区域包含无效信息,所以要使用 lengths 追踪有效部分。

参考

  1. ^PyTorch IR https://github.com/pytorch/pytorch/wiki/PyTorch-IR

推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 本文介绍了PhysioNet网站提供的生理信号处理工具箱WFDB Toolbox for Matlab的安装和使用方法。通过下载并添加到Matlab路径中或直接在Matlab中输入相关内容,即可完成安装。该工具箱提供了一系列函数,可以方便地处理生理信号数据。详细的安装和使用方法可以参考本文内容。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • Android工程师面试准备及设计模式使用场景
    本文介绍了Android工程师面试准备的经验,包括面试流程和重点准备内容。同时,还介绍了建造者模式的使用场景,以及在Android开发中的具体应用。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文介绍了Perl的测试框架Test::Base,它是一个数据驱动的测试框架,可以自动进行单元测试,省去手工编写测试程序的麻烦。与Test::More完全兼容,使用方法简单。以plural函数为例,展示了Test::Base的使用方法。 ... [详细]
  • ALTERTABLE通过更改、添加、除去列和约束,或者通过启用或禁用约束和触发器来更改表的定义。语法ALTERTABLEtable{[ALTERCOLUMNcolu ... [详细]
  • 本文讨论了clone的fork与pthread_create创建线程的不同之处。进程是一个指令执行流及其执行环境,其执行环境是一个系统资源的集合。在调用系统调用fork创建一个进程时,子进程只是完全复制父进程的资源,这样得到的子进程独立于父进程,具有良好的并发性。但是二者之间的通讯需要通过专门的通讯机制,另外通过fork创建子进程系统开销很大。因此,在某些情况下,使用clone或pthread_create创建线程可能更加高效。 ... [详细]
  • 本文介绍了如何使用Express App提供静态文件,同时提到了一些不需要使用的文件,如package.json和/.ssh/known_hosts,并解释了为什么app.get('*')无法捕获所有请求以及为什么app.use(express.static(__dirname))可能会提供不需要的文件。 ... [详细]
  • JDK源码学习之HashTable(附带面试题)的学习笔记
    本文介绍了JDK源码学习之HashTable(附带面试题)的学习笔记,包括HashTable的定义、数据类型、与HashMap的关系和区别。文章提供了干货,并附带了其他相关主题的学习笔记。 ... [详细]
  • NotSupportedException无法将类型“System.DateTime”强制转换为类型“System.Object”
    本文介绍了在使用LINQ to Entities时出现的NotSupportedException异常,该异常是由于无法将类型“System.DateTime”强制转换为类型“System.Object”所导致的。同时还介绍了相关的错误信息和解决方法。 ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
author-avatar
biosan
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有