上一篇文章我讲了如何利用运算符重载(operator overloading)实现自动微分。
罗秀哲:一天实现自己的自动微分zhuanlan.zhihu.com但是如果你有使用过诸如PyTorch,Flux/Tracker,AutoGrad这一类基于运算符重载的自动微分库,就会发现,这些库有两个通病:
在Yan LeCun等人的号召下,想要实现可微分编程(Differentiable Programming)如果没有上面这两个功能可不行。不能对控制流进行微分叫什么可微分编程?因为我们希望我们编写的任意在数学上成立的程序都可以进行自动微分。此外上篇文章的评论区里有人提到了TensorFlow的自动产生符号导数然后加入计算图的功能。而实际上如果对编译器数学的同学会知道这类符号计算就是一个简单的编译步骤。
这些问题在源对源(source to source)的自动微分下实现起来将非常自然,这篇文章将实现一个不带控制流的简单版本,而完整的版本已经在Julia中通过staged programming的方式实现了,这个包叫做 Zygote.jl [1]。它中文翻译很搞笑,叫卵子。使用这个包的人都用了个卵子自动微分。我在今年JuliaCon的周五Hackthon期间在Zygote的作者Mike的帮助下实现了一个简单版本的。
如果你想要阅读更详细的代码的,将这篇文章实现的简单版本放在了这个repo里,
Roger-luo/YASSAD.jlgithub.com英文版在我的blog:
http://blog.rogerluo.me/2019/07/27/yassad/blog.rogerluo.me不过,在我们开始编写程序前,让我们来回顾一些基础知识。
首先让我们来简单了解一下Julia语言是如何进行编译的。
到此位置我们完成了代码的初始化过程。
为了描述Julia是如何编译的,我从之前的JuliaCon的报告里拿出来一副图。
这张图更清楚一些,你可以看到与静态语言不同的是,每次函数调用都会经过编译过程。Julia中的编译(包括JIT编译)是以函数调用为界的。
完整的介绍SSA需要很大的篇幅,足够写一本书了[2]。 但是不用担心,我们这里需要用到的部分很简单。你只需要知道下面几个概念就可以了
如果你已经阅读过我上篇文章,我相信你已经了解计算图这个概念,但是现在我们要重新思考什么是计算图。我们回顾一下上一节用到的图
在进行AD的计算的时候,我们将计算过程表示为一个计算图。每个节点都会使用一个运算符(operator)然后获得一个中间值(一个节点),接下来这个节点的值会和函数一起暂时存起来,在后向传播的时候使用。也就是说每个节点的中间变量都只会被赋值一次,否则就不能唯一对应一个算符。而每个节点有两个函数,一个是代表前向计算的函数,另外一个则是代表后向传播的导数函数。
很自然的,我想你已经发现了,这就是一个SSA的格式。而所谓的自动微分,其实就是我们正常的前向程序的某一个对偶的程序(也就是一个对偶的函数)。而实际上,没有控制流的计算图我们称之为 Wengert Lists [3]。大部分基于运算符重载的自动微分实际上都是实现了这个算法,有时候它也被称为Tape。而SSA格式则更加一般,它包含了控制流。所以我们可以通过对SSA格式来计算自动微分来实现对控制流的自动微分。这也是Zygote第一篇文章所提到的方法[4]。
而由于后向传播的函数只是前向传播的函数的伴随(adjoint,实际上一些数学家也认为后向传播可以定义在一个对偶空间上,所以我们不妨就使用这个称谓)。我们不妨直接将这个函数写做一个前向传播函数+一个闭包的格式。
function forward(::typeof(your_function), xs...)# 函数声明output = # 函数输出output, function (Δ)# 这是一个闭包(或者你可以理解成一个能够获取forward函数的local变量的匿名函数)end
end
实现成闭包的好处是实现一些需要使用前向传播的中间值的导数的时候我们可以把这些中间值以闭包函数的状态(state)的方式托管给编译器,而不需要像我在上篇文章里一样,手动将其存在一个Node对象中。我们称这个返回的闭包函数为pullback。
所以假如我们想要获得一个下面这个函数的导数
foo(x) = bar(baz(x))
如果手动做这件事情,我们只需要定义一个 forward 函数
function forward(::typeof(foo), x)x1, back1 = forward(baz, x)x2, back2 = forward(bar, x1)return x2, function (Δ)dx1 = back2(Δ)dx2 = back1(dx1)return dx2end
end
实际上,一般的来说,一段没有控制流的程序的伴随,就是倒着把这段程序中的所有函数调用换成伴随程序的调用,变量换成其对应的伴随变量(adjoint variable)。但是我们如何通过一个函数定义来产生这个forward函数呢?有人可能会说宏,但是宏会要求我们在所有可以进行求导的函数前面都要这样标记,这是我们不希望的,我们希望未来使用这个自动微分的时候我们不需要写任何额外的东西。
而由于SSA格式的IR已经将所有的语法糖,函数都转换成低级表示了,这也就意味着仅仅需要定义一些原始表示,我们就可以用上面这个规则(实际上就是链式法则)组合出非常多的导数,而这些导数的生成都发生在编译时期,所以不会反复占用运行时间,并且这也能帮助我们未来进一步进行一些有针对性的优化。
所以我们想在SSA IR上来做这件事,但是怎么做呢?我们知道宏可以用来修改代码的解析过程(parsing),而Julia里还有另外一个元素用来实现对类型推导和IR的修改,生成函数(generated function)。生成函数可以通过一个宏来声明
@generated function foo(a, b, c)return :(1 + 1)
end
它看起来像是一个普通函数,但是注意它是发生在类型推导期间的
所以你只能知道函数变量 a, b, c的类型信息,我们可以通过类型信息产生两种格式的代码,一种是AST表达式,在Julia里叫Expr,另外一种就是我们的SSA IR,是一种叫CodeInfo的Julia对象。IRTools[5]里提供了操作SSA IR的一些工具,我们将使用这个包来编写产生这个forward函数的代码。
我们可以通过 code_ir 宏来拿到函数的ir对象,这个对象是被IRTools处理过的,它的类型是IR。和 code_typed 宏或者 code_lowered 宏得到的对象不同的是,IR类型实现了一些方便的函数操作,并且IR类型中不会保存变量的名称,所有的变量都用 %数字 来表示
julia> @code_ir foo(1.0)
1: (%1, %2)%3 = (Main.baz)(%2)%4 = (Main.bar)(%3)return %4
注意,你会发现,即便这里 baz和bar这两个函数没定义也不会报错,这是因为Julia本质上还是一个动态语言,所以只要在真正运行的时候才会报错。
这个格式下,每一行代码都绑定了一个变量,等号右边我们称之为声明(statement),左边是变量(variable)。你可以用类似字典的接口来使用这个对象,例如
julia> using IRTools: varjulia> ir[var(3)]
IRTools.Statement(:((Main.baz)(%2)), Any, 1)
它会给你一个声明对象,这个对象里记录了这段声明的表达式,这个宏给你的是没有经过类型推导的IR。所以后面的Any就是这个变量的类型。Any也是Julia中唯一的静态类型。为了简单起见,我们这里不介绍带类型的IR(因为原理上是类似的但是实现细节有一些不同)。最后数字1是指这段声明所在的行号。
前面的1是什么意思呢?在SSA格式中我们用这样的代码块来表示分支,我们不妨写一个ifelse语句看看
julia> function foo(x)if x > 1bar(x)elsebaz(x)endend
foo (generic function with 1 method)julia> @code_ir foo(1.0)
1: (%1, %2)%3 = %2 > 1br 3 unless %3
2:%4 = (Main.bar)(%2)return %4
3:%5 = (Main.baz)(%2)return %5
ifelse在低级表示中是通过branch语句表示的,实际上循环也是类似的。Julia里的循环只是对iterate函数的语法糖而已。所以我们只要能够对br语句进行微分,我们就可以对控制流微分了。
julia> function foo(x)for x in 1:10bar(x)endbaz(x)end
foo (generic function with 1 method)julia> @code_ir foo(1.0)
1: (%1, %2)%3 = 1:10%4 = (Base.iterate)(%3)%5 = %4 === nothing%6 = (Base.not_int)(%5)br 3 unless %6br 2 (%4)
2: (%7)%8 = (Core.getfield)(%7, 1)%9 = (Core.getfield)(%7, 2)%10 = (Main.bar)(%8)%11 = (Base.iterate)(%3, %9)%12 = %11 === nothing%13 = (Base.not_int)(%12)br 3 unless %13br 2 (%11)
3:%14 = (Main.baz)(%2)return %14
那么这个IR是怎么获得的呢?为了获得IR,我们首先需要知道这个通用函数(generic function)被派发了哪个方法(method),在Julia里每个通用函数都有一个方法表(method table)你可以通过这个函数的类型标签来获得具体的方法。例如这个foo函数,每次调用 foo(1.0) 的时候,Julia都会产生下面的标签
Tuple{typeof(foo), Float64}
然后查找这个类型标签对应的方法,然后进行编译。IRTools里提供了一个meta函数来查找这些信息。
julia> T = Tuple{typeof(foo), Float64}
Tuple{typeof(foo),Float64}julia> m = IRTools.meta(T)
Metadata for foo(x) in Main at REPL[2]:1
meta中存储了比较详细的IR信息,函数变量信息等。我们可以通过这个meta信息构建一个可以操作的IR对象。
julia> IRTools.IR(m)
1: (%1, %2)%3 = (Main.baz)(%2)%4 = (Main.bar)(%3)return %4
例如我们可以向这个IR里插入一个Expr,就好像是操作一个列表一样
julia> push!(ir, :(1+1))
%5julia> ir
1: (%1, %2)%3 = (Main.baz)(%2)%4 = (Main.bar)(%3)%5 = 1 + 1return %4
IRTools会自动为你增加变量名。同理我们可以使用insert!操作在第四个变量的前面插入一行
julia> using IRTools: varjulia> insert!(ir, var(4), :(1+1))
%5julia> ir
1: (%1, %2)%3 = (Main.baz)(%2)%5 = 1 + 1%4 = (Main.bar)(%3)return %4
或者我们可以在第四个变量后面插入另外一行
julia> using IRTools: insertafter!julia> insertafter!(ir, var(4), :(2+2))
%6julia> ir
1: (%1, %2)%3 = (Main.baz)(%2)%5 = 1 + 1%4 = (Main.bar)(%3)%6 = 2 + 2return %4
有了这些工具我们基本上可以来做前向传播的IR变换了。我们目标是把每个函数调用换成对forward函数的调用,然后把forward函数返回的pullback收集起来产生一个闭包。但是等等,我好像没有讲在Julia的SSA IR里闭包是什么?我们一会儿再说这件事。我们先看看要如何把函数调用换成forward。
我们先拿出来一个声明看看它长的什么样
julia> dump(ir[var(3)])
IRTools.Statementexpr: Exprhead: Symbol callargs: Array{Any}((2,))1: GlobalRefmod: Module Mainname: Symbol baz2: IRTools.Variableid: Int64 2type: Anyline: Int64 1
那么实际上我们只要检查这个声明的表达式里的head成员是否是call就行了。但是我们不想改变我们原来的IR,IRTools里提供了一个Pipe对象用来辅助进行IR变换,这样我们可以直接往这个对象里进行插入,它会自动将这些插入操作放进一个类似的IR对象里。转换的IR存储在to成员中。Pipe一开始只会构建一个拥有同样变量的代码块
julia> IRTools.Pipe(ir).to
1: (%1, %2)
我们将这个函数命名为register,它的功能类似于我们上次实现的register函数,只是这一次不需要你手动来通过一个特别的类型派发过来了。注意我们这里在演示的时候为了方便在REPL里测试,所有的forward函数都定义在REPL里,也就是名为Main的module里如果你将这些代码写进了自己的module那么你需要修改这里的Main。
function register(ir)pr = Pipe(ir)argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)endendfinish(pr)
end
解释一下上面的代码,首先因为我们不再是调用原本的函数了,新的函数调用变成了多了一个变量的forward
forward(f, args...)
所以我们的IR也需要在最前面增加一个变量。
然后接下来我们遍历所有的SSA变量和其对应的声明(statement),如果声明的标签是一个 :call 的话说明这是一个函数调用,我们把它换成一个 forward调用。但是同时我们要保留原来代码的行号,否则报错信息会不准确。然后我们把这个声明插入到原来的变量的位置上。
但是由于forward会返回两个变量,一个是我们前向传播的值,另外一个是后向传播之后要用的pullback,所以我们需要对返回的这个tuple调用一次getindex。(因为在SSA里没法写 x, y = forward...,想想为什么?) 。但是注意getindex是不能在编译时期调用的,我们要插入一个调用getindex的表达式。这个被实现成一个 xgetindex 函数了
xgetindex(x, i...) = xcall(Base, :getindex, x, i...)
然后我们看看这个函数把IR变成什么了
julia> register(ir)
1: (%3, %1, %2)%4 = (Main.forward)(Main.baz, %2)%5 = (Base.getindex)(%4, 1)%6 = (Main.forward)(Main.bar, %5)%7 = (Base.getindex)(%6, 1)return %7
不错。我们成功把这个函数里的调用修改成了对forward函数的调用。
回过头来,我们考虑闭包的问题。是的在低级表示里,是没有闭包的。但是我们可以用一个类型把它存起来,然后让这个类型变成一个callable。这个类型会同时在类型参数S里存下函数的标签(signature),这样我们在调用pullback的时候可以查到函数的IR。
struct Pullback{S, T}data::T
endPullback{S}(data::T) where {S, T} = Pullback{S, T}(data)
这里 data 里将会存一个Tuple,里面是所有的forward返回的pullback,它们会按照调用顺序存储。
为了将pullback都存起来,我们首先要修改上面的实现,把pullback从forward函数的返回值里取出来,然后把他们都存进一个tuple然后再存到一个Pullback对象里去。因为Pullback需要函数的标签,所以我们要增加一个变量用来输入函数标签。
function register(ir, F)pr = Pipe(ir)pbs = Variable[]argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))push!(pbs, substitute(pr, J))endendpr = finish(pr)v = push!(pr, xtuple(pbs...))pbv = push!(pr, Expr(:call, Pullback{F}, v))return pr
end
这里xtuple类似于xgetindex是一个用来产生调用tuple构造函数表达式的函数
xtuple(xs...) = xcall(Core, :tuple, xs...)
最后让我们把pullback和原本的返回值打包成tuple返回
function register(ir, F)pr = Pipe(ir)pbs = Variable[]argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))push!(pbs, substitute(pr, J))endendpr = finish(pr)v = push!(pr, xtuple(pbs...))pbv = push!(pr, Expr(:call, Pullback{F}, v))ret = pr.blocks[end].branches[end].args[1]ret = push!(pr, xtuple(ret, pbv))pr.blocks[end].branches[end].args[1] = retreturn pr
end
这里return实际上也是一个分支语句,没有分支的函数return是最后一个(也是唯一一个)代码块的最后一个分支的第一个变量,我们修改它为我们要返回的tuple。
好了现在让我们来看看register会把代码变成什么
julia> register(ir, Tuple{typeof(foo), Float64})
1: (%3, %1, %2)%4 = (Main.forward)(Main.baz, %2)%5 = (Base.getindex)(%4, 1)%6 = (Base.getindex)(%4, 2)%7 = (Main.forward)(Main.bar, %5)%8 = (Base.getindex)(%7, 1)%9 = (Base.getindex)(%7, 2)%10 = (Core.tuple)(%9, %6)%11 = (Pullback{Tuple{typeof(foo),Float64},T} where T)(%10)%12 = (Core.tuple)(%8, %11)return %12
接下来让我们来实现forward函数
@generated function forward(f, xs...)# ....
end
注意在generated function里,函数变量的值都是他们的类型。我们首先产生一个函数标签来获取函数的meta
@generated function forward(f, xs...)T = Tuple{f, xs...}m = IRTools.meta(T)m === nothing && return
end
如果meta获取不到返回了nothing说明这个函数标签不存在,也就是说这个method没有定义。我们直接返回。
然后我们用我们的register函数产生这个forward函数的定义
@generated function forward(f, xs...)T = Tuple{f, xs...}m = IRTools.meta(T)m === nothing && returnfrw = register(IR(m), T)
end
但是frw的类型是IR,它并不是真正的Julia中间表示,它只是IRTools创建的一个方便操作的表示,它没有保存函数的变量名称。让我们来手动给他们加上
@generated function forward(f, xs...)T = Tuple{f, xs...}m = IRTools.meta(T)m === nothing && returnfrw = register(IR(m), T)argnames!(m, Symbol("#self#"), :f, :xs)frw = varargs!(m, frw, 2)return IRTools.update!(m, frw)
end
然后我们将forward这个函数的第二个变量xs标记为可变参数(var arg),最后调用update!使用新的meta来产生forward函数的IR。我们不妨来看看这个forward函数的IR是什么样子的
julia> @code_ir forward(foo, 1.0)
1: (%1, %2, %3)%4 = (Base.getfield)(%3, 1)%5 = (Main.forward)(Main.baz, %4)%6 = (Base.getindex)(%5, 1)%7 = (Base.getindex)(%5, 2)%8 = (Main.forward)(Main.bar, %6)%9 = (Base.getindex)(%8, 1)%10 = (Base.getindex)(%8, 2)%11 = (Core.tuple)(%10, %7)%12 = (Main.Pullback{Tuple{typeof(foo),Float64},T} where T)(%11)%13 = (Core.tuple)(%9, %12)return %13
运行一下
julia> forward(foo, 1.0)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
Stacktrace:[1] * at ./float.jl:399 [inlined][2] forward(::typeof(*), ::Float64, ::Float64) at /Users/roger/.julia/dev/YASSAD/src/compiler.jl:0[3] baz at ./REPL[4]:1 [inlined][4] forward(::typeof(baz), ::Float64) at /Users/roger/.julia/dev/YASSAD/src/compiler.jl:0[5] foo at ./REPL[2]:1 [inlined][6] forward(::typeof(foo), ::Float64) at /Users/roger/.julia/dev/YASSAD/src/compiler.jl:0[7] top-level scope at none:0
我们发现它会递归地调用这个forward函数,但是我们并没有为primitive地方法定义导数规则,所以它会在递归到某一步地时候报错。我们可以通过定义一些forward函数作为primitive,比如*
julia> forward(::typeof(*), a::Real, b::Real) = a * b, Δ->(Δ*b, a*Δ)julia> forward(foo, 1.0)
(1.0, YASSAD.Pullback{.....}
我们就会得到一个非常长地Pullback,但是我们还不能调用它,因为我们还没有定义Pullback是如何被调用的。
那么我们成功的将前向求值的过程变成了某种可以追踪的格式,我们接下来来产生Pullback的IR。首先类似于forward我们可以这样定义Pullback
@generated function (::Pullback{S})(delta) where Sm = IRTools.meta(S)m === nothing && returnir = IR(m)_, pbs = register(ir, S)back = adjoint(ir, pbs)argnames!(m, Symbol("#self#"), :delta)return IRTools.update!(m, back)
end
因为后向传播是分开调用的,这个时候我们已经没法直接拿走上面我们处理过的前向传播的IR了,所以我们要再次调用一遍register,但是不用担心,编译只会发生一次,所以它并不会占用我们真正的运行时间。在产生我们前向传播IR的伴随程序的时候我们还需要知道都有哪些变量产生了pullback,所以我们需要一个词典(Dict)来记录这个对应关系。我们把前面register定义修改一下
function register(ir, F)pr = Pipe(ir)pbs = Dict{Variable, Variable}()argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))pbs[v] = substitute(pr, J)endendpr = finish(pr)v = push!(pr, xtuple(values(pbs)...))pbv = push!(pr, Expr(:call, Pullback{F}, v))ret = pr.blocks[end].branches[end].args[1]ret = push!(pr, xtuple(ret, pbv))pr.blocks[end].branches[end].args[1] = retreturn pr, pbs
end
因为伴随程序和原来的IR是相反的顺序,所以我们不能用Pipe了,我们需要创建一个空的IR对象。然后手动给它增加两个变量,一个是Pullback这个对象自己,另外一个是后向传播的梯度。
adj = empty(ir)
self = argument!(adj)
delta = argument!(adj)
首先让我们从我们的对象里取出所有的pullback,这里getfield函数是语法糖 . 的低级函数调用,这句话相当于 self.data
pullbacks = pushfirst!(adj, xcall(:getfield, self, QuoteNode(:data)))
然后我们倒着遍历一遍IR中的所有变量
vars = keys(ir)for k in length(vars):-1:1v = vars[k]ex = ir[v].exprif haskey(pbs, v)pbv = insertafter!(adj, pullbacks, xcall(:getindex, pullbacks, k))g = push!(adj, Expr(:call, pbv, v))endend
如果这个变量在我们的pullback词典里的话,我们就在最前面取出来,命名这个SSA变量为pbv,然后调用它。但是这里有一个问题,如果一个变量有多个梯度(它可能参与了几个函数调用,所以我们会获得几个梯度)我们需要将这几个梯度累加到一起。所以我们需要把这些梯度按照变量记下来
grads = Dict()
然后我们实现一个grad函数,当输入两个变量的时候后面的变量是梯度变量,我们将梯度变量记录在这个变量的列表里,然后返回这个梯度变量
grad(x, x̄) = push!(get!(grads, x, []), x̄)
然后当我们输入一个变量的时候我们返回一个将这些变量累加之后的SSA变量
grad(x) = xaccum(adj, get(grads, x, [])...)
这里xaccum和前面一样是一个用来产生调用accum函数的方法。但是Julia里自带的累加accum函数和我们需要的并不一样(它定义在数组而非变量上),我们来自己定义一个。
xaccum(ir) = nothing
xaccum(ir, x) = x
xaccum(ir, xs...) = push!(ir, xcall(YASSAD, :accum, xs...))
accum() = nothing
accum(x) = x
accum(x, y) =x == nothing ? y :y == nothing ? x :x + yaccum(x, y, zs...) = accum(accum(x, y), zs...)accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
最后我们的pullback将会返回我们每个前馈函数的变量所对应的梯度变量。而如果我们的变量没有梯度。例如forward中的第一个函数变量,在大部分情况下是没有梯度的,但是如果这个函数也是一个闭包函数,或者是一个callable就可以有梯度。所以我们现在约定每个pullback返回的梯度第一个是forward中第一个变量的梯度,它可以是nothing,pullback总是返回和forward函数变量同样数量的梯度。
所以最后我们的adjoint函数就是
function adjoint(ir, pbs)adj = empty(ir)self = argument!(adj)delta = argument!(adj)pullbacks = pushfirst!(adj, xcall(:getfield, self, QuoteNode(:data)))grads = Dict()grad(x, x̄) = push!(get!(grads, x, []), x̄)grad(x) = xaccum(adj, get(grads, x, [])...)grad(last(keys(ir)), delta)vars = keys(ir)for k in length(vars):-1:1v = vars[k]ex = ir[v].exprif haskey(pbs, v)pbv = insertafter!(adj, pullbacks, xcall(:getindex, pullbacks, k))g = push!(adj, Expr(:call, pbv, grad(v)))for (i, x) in enumerate(ex.args)x isa Variable || continuegrad(x, push!(adj, xgetindex(g, i)))endendendgs = [grad(x) for x in arguments(ir)]Δ = push!(adj, xtuple(gs...))return!(adj, Δ)return adj
end
和上篇文章一样,让我们用矩阵乘法+矩阵的迹来试试效果!我们可以直接使用Julia自带的类型了!
using LinearAlgebrafunction forward(::typeof(*), A::Matrix, B::Matrix)A * B, function (Δ::Matrix)Base.@_inline_meta(nothing, Δ * B', A' * Δ)end
endfunction forward(::typeof(tr), A::Matrix)tr(A), function (Δ::Real)Base.@_inline_meta(nothing, Δ * Matrix(I, size(A)))end
endjulia> using LinearAlgebra, BenchmarkToolsjulia> mul_tr(A::Matrix, B::Matrix) = tr(A * B)
mul_tr (generic function with 1 method)julia> A, B = rand(30, 30), rand(30, 30);julia> mul_tr(A, B)
216.7247235502547julia> z, back = forward(mul_tr, A, B);julia> julia> back(1);
而性能和我们手动实现的梯度也是基本一样的 (它其实和我们手动写出来的对于机器来说是一样的)。这是我们上次手动实现的梯度
julia> @benchmark bench_tr_mul_base($(rand(30, 30)), $(rand(30, 30)))
BenchmarkTools.Trial: memory estimate: 28.78 KiBallocs estimate: 5--------------minimum time: 10.696 μs (0.00% GC)median time: 13.204 μs (0.00% GC)mean time: 24.075 μs (43.31% GC)maximum time: 62.964 ms (99.97% GC)--------------samples: 10000evals/sample: 1
这是我们自动产生的梯度
julia> @benchmark tr_mul($A, $B)
BenchmarkTools.Trial: memory estimate: 36.17 KiBallocs estimate: 14--------------minimum time: 12.921 μs (0.00% GC)median time: 15.659 μs (0.00% GC)mean time: 27.304 μs (40.97% GC)maximum time: 60.141 ms (99.94% GC)--------------samples: 10000evals/sample: 1
到此为止我们已经实现了一个非常简单的源对源自动微分,虽然这个自动微分库已经可以使用大部分微分规则。但是在上面的IR变换过程里我们没有处理控制流。而更完整的实现Zygote可以对几乎一切Julia对象进行求导,这包括:自定义类型,控制流,外部函数调用(比如Zygote可以使用PyTorch定义的函数,所以兼容你的旧PyTorch模型),in-place函数。而例子甚至包括:部分我们的量子算法框架Yao[6]定义的量子线路,ODE求解器(NeuralODE很需要这个),光线追踪器(Ray Tracer)。详见Zygote今年NeurIPS的论文[7]。而Zygote本身的实现和我们上面介绍几乎一样简单,我们上面的实现一共用了132行代码,而完整的实现的编译部分也只有495行Julia。而有了编译器自带的反射机制我们连计算图都不需要了,反而能实现控制流的微分。
我们回顾上面的实现,其实可以发现我们实际上是根据一个上下文(Context)来进行函数派发。而非根据函数的标签进行派发。我们上面的函数变换实际上是根据我们需要求导这个上下文,重新派发了我们原先的函数调用。Julia社区实际上为这个更加一般的派发机制,利用上面修改IR的方式实现了一个编译器扩展Cassette[8]。Cassette能够根据根据上下文这个信息来修改某段不知道出处的源代码。Cassette的测试中实际上还有一个非常简单的自动微分实现[9]。且由于Julia本身的动态编译机制,我们甚至能够通过这样的方式修改JIT编译过程。这使得我们能够做很多潜在的特定领域的编译和优化。例如:
这些特性在其它语言中可能并不会这么容易。在Python中不通过编写完整的编译器则几乎不可能。而受限于Python即便是基于XLA的JAX[14]也无法处理副作用,(有一定动态性的)控制流。而这在Julia中非常简单和直接,并且很自然就可以和整个生态中的其它package进行交互组合,与其费劲全力编写各种各样编译器不如来试试这个和Python很像的新语言吧。(手动狗头)
Julia中文社区将在8月24日在北京微软大厦办一次活动,这次talk的内容质量都很好,Zygote的作者本人也会给一个online talk:https://discourse.juliacn.com/t/topic/2044/3 欢迎来参加。(但是我去不了了哈)