本文改写自OptimizingCUDARecurrentNeuralNetworkswithTorchScript,有删改。RNN在多种NLP任务上有良好表现,PyTorch自带了
本文改写自 Optimizing CUDA Recurrent Neural Networks with TorchScript,有删改。
RNN 在多种 NLP 任务上有良好表现,PyTorch 自带了几种 RNN 的典型实现(例如 Elman RNN,GRU,LSTM 以及它们的 multi-layered 和 bidirectional 版本)。
有时用户想修改 RNN 的一些实现细节。例如,用户想将 Layer Normalization 应用到 LSTM 中,这一改动很难实现,因为 PyTorch CUDA LSTM 是高度一体化的。此时用户会尝试使用 PyTorch 的基本 Operator 来定制他们想要的 LSTM,这会带来开销:RNN 大量的使用 Operator,而多数 Operator 会在 GPU 上启动至少一个 kernel。
可以使用 TorchScript 来改善这种情况:它可以优化代码并 Fuse Operation,以降低在 GPU 上启动的 kernel 个数,并使 kernel 质量更高。
本文的目标是让用户自然、快速的实现 RNN,并达到和手工优化的 CUDA kernel 同样的性能。
本文代码链接 github.com
Operator Fuse 的局限性
如果不了解 Operator Fuse,可参考这篇文章:内核融合:GPU深度学习的“加速神器”
PyTorch JIT 能将相邻的 element-wise 操作 Fuse 到一个 FusionGroup 中,这个 FusionGroup 只会启动一个 GPU/CPU kernel。
这意味着,如果使用了较为复杂的 Operator(例如:混合了 element-wise 的 reduce 操作),JIT 识别可 Fuse 的 Operator 时会遇到困难。此时,可以尝试分离 reduce 操作 与 element-wise 操作,这样做的话,JIT 就可以将几个 element-wise 操作 Fuse 在一个 Fusion Group 中了。
在本文 LSTM Cell(forward)
一节中,我们可以看到 PyTorch JIT 会尝试着在保证程序正确性的情况下,将 element-wise 操作尽可能放在一起,从而进行 Operator Fuse。
PyTorch JIT 对 LSTM 的优化
LSTM Cell(forward)
几乎 LSTM 中所有的计算都发生在 LSTMCell 中,下面是 LSTMCell 在 TorchScript 下的一种实现:
class LSTMCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
self.bias_ih = Parameter(torch.randn(4 * hidden_size))
self.bias_hh = Parameter(torch.randn(4 * hidden_size))
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = state
gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return hy, (hy, cy)
重排 Chunk 和 Pointwise Ops,挖掘更多的 Fuse 机会
上面实现的 LSTMCell
- 将 gate 加(Pointwise Ops)在一起
- 将加在一起的结果 chunk 为四小块
- 对每个小块执行激活函数(Pointwise Ops)
如果不重排 Chunk 和 Pointwise Ops,这个实现会产生两个 Fusion Group(chunk 之前和 chunk 之后)。
为了避免这个问题,JIT 调换了第一个 Pointwise Ops 和 chunk 的作用对象和顺序(如图):
调换前:对输入执行 Pointwise Ops,对输出执行 Chunk(图左侧路径)
调换后:对输入执行 Chunk,对输出执行 Pointwise Ops(图右侧路径)
通过这一重排,原来的两个 Fusion Group 就可以 Fuse 为一个了,我们用 graph_for 来看 JIT 将哪些操作 Fuse 到了一起:
# get inputs and states for LSTMCell
inputs = get_lstm_inputs()
# instantiate a ScriptModule
cell = LSTMCell(input_size, hidden_size)
# print the optimized graph using graph_for
out = cell(inputs)
print(cell.graph_for(inputs))
输出是优化过的 TorchScript Graph(又名 PyTorch JIT IR[1])
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30)
with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *),
%29 : Float(*),
%33 : Float(*),
%40 : Float(*, *),
%43 : Float(*, *),
%45 : Float(*, *),
%48 : Float(*, *)):
%49 : Float(*, *) = aten::t(%48)
%47 : Float(*, *) = aten::mm(%45, %49)
%44 : Float(*, *) = aten::t(%43)
%42 : Float(*, *) = aten::mm(%40, %44)
...some broadcast sizes operations...
%hy : Float(*, *), %287 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) =prim::FusionGroup_0(%13, %346, %345, %344, %343)
...some broadcast sizes operations...
return (%hy, %cy, %49, %44, %196, %199, %340, %192, %325, %185, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %395, %396, %287)
with prim::FusionGroup_0 = graph(%13 : Float(*, *),
%71 : Tensor,
%76 : Tensor,
%81 : Tensor,
%86 : Tensor):
...some chunks, constants, and add operations...
%ingate.1 : Float(*, *) = aten::sigmoid(%38)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%34)
%cellgate.1 : Float(*, *) = aten::tanh(%30)
%outgate.1 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
%11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%cy : Float(*, *) = aten::add(%14, %11, %69)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %4)
return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
从输出可以看到,子图 prim::FusionGroup_0 Fuse 了 LSTMCell 的所有 element-wise 操作(sigmoid,tanh,mul,add)(transpose 和 matrix multiplication 不是 element-wise 操作)。最后生成的 IR 中只有一个 FusionGroup,而不是两个。
LSTM Layer
下面是 LSTM Layer 在 TorchScript 下的一种实现:
class LSTMLayer(jit.ScriptModule):
def __init__(self, cell, *cell_args):
super(LSTMLayer, self).__init__()
self.cell = cell(*cell_args)
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
inputs = input.unbind(0)
outputs = torch.jit.annotate(List[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
return torch.stack(outputs), state
forward
循环展开(Loop Unrolling)
JIT 将代码中的循环展开了(对于大循环,仅展开一小部分),这使得 Fuser 可以在多个迭代步骤的循环体上做 Fuse,挖掘对 For Loops Control Flow
更多的优化机会。
Batch Matrix Multiplication
将多个矩阵乘法 Batch 在一起执行:
backward
Automatic Differentiation
不同于 PyTorch 原有的,基于 Tape 的 Autograd;TorchScript 使用了基于符号微分的自动微分机制,这使得运行时不必记录运算轨迹了,减轻了对存储和计算的负担。
变长序列最佳实践
TorchScript 不支持 PackedSequence。在处理变长序列时,最好将它们 pad 到同一长度放到单个 Tensor 中,再传给 TorchScript LSTM。例如:
sequences = [...] # List[Tensor], each Tensor is T' x C
padded = torch.utils.rnn.pad_sequence(sequences)
lengths = [seq.size(0) for seq in sequences]
padded # T x N x C, where N is batch size and T is the max of all T'
model = LSTM(...)
output, hiddens = model(padded)
output # T x N x C
output
可能会在 pad 区域包含无效信息,所以要使用 lengths
追踪有效部分。
参考
- ^PyTorch IR https://github.com/pytorch/pytorch/wiki/PyTorch-IR