热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

表达式最后一天_一天实现你自己的源到源自动微分

上一篇文章我讲了如何利用运算符重载(operatoroverloading)实现自动微分。罗秀哲:一天实现自己的自动微分​zhuanlan

上一篇文章我讲了如何利用运算符重载(operator overloading)实现自动微分。

罗秀哲:一天实现自己的自动微分​zhuanlan.zhihu.com

但是如果你有使用过诸如PyTorch,Flux/Tracker,AutoGrad这一类基于运算符重载的自动微分库,就会发现,这些库有两个通病:

  • 只能使用框架所提供的函数和矩阵/张量类型,如果想要对一般的程序进行求导就不行了
  • 无法处理控制流,因为简单的运算符重载无法记录下来控制流

在Yan LeCun等人的号召下,想要实现可微分编程(Differentiable Programming)如果没有上面这两个功能可不行。不能对控制流进行微分叫什么可微分编程?因为我们希望我们编写的任意在数学上成立的程序都可以进行自动微分。此外上篇文章的评论区里有人提到了TensorFlow的自动产生符号导数然后加入计算图的功能。而实际上如果对编译器数学的同学会知道这类符号计算就是一个简单的编译步骤。

这些问题在源对源(source to source)的自动微分下实现起来将非常自然,这篇文章将实现一个不带控制流的简单版本,而完整的版本已经在Julia中通过staged programming的方式实现了,这个包叫做 Zygote.jl [1]。它中文翻译很搞笑,叫卵子。使用这个包的人都用了个卵子自动微分。我在今年JuliaCon的周五Hackthon期间在Zygote的作者Mike的帮助下实现了一个简单版本的。

如果你想要阅读更详细的代码的,将这篇文章实现的简单版本放在了这个repo里,

Roger-luo/YASSAD.jl​github.com
1d8906634c59d2bae54f4576516ce19e.png

英文版在我的blog:

http://blog.rogerluo.me/2019/07/27/yassad/​blog.rogerluo.me

不过,在我们开始编写程序前,让我们来回顾一些基础知识。

Julia语言的编译过程

首先让我们来简单了解一下Julia语言是如何进行编译的。

  1. 首先,所有的代码本质上都是一些字符串(string),存储在硬盘上的文本文件中
  2. 我们首先要解析(parse)这些字符串,得到一个抽象语法数(Abstract Syntax Tree,AST)
  3. 而 AST 里有一些节点是宏,这些宏是一些只接受编译时期变量的,里面描述了如何产生更多的 AST,在这一步将会运行这些宏,我们成为 展开AST。你可以通过 macroexpand 宏查看这一步的编译结果
  4. 这个时候我们再将AST里的语法糖等节点全部替换为函数调用,并且使用SSA(Static Single Assignment)形式的IR作为更低级的表示。什么是SSA IR?我们将在后面介绍

到此位置我们完成了代码的初始化过程。

  1. 然后我们的函数会在被派发的时候才会被继续编译,这是因为对于一般的函数(generic function)我们是无法在编译时期就确定这个函数的变量类型的,从而无法产生定制的机器码。例如对机器来说 Int 和 Float64 即便都是加法,对整数来说调用的可能是 leaq 指令,而浮点数则可能是 vaddsd
  2. 然后我们开始进行类型推导(type inference),这是为什么你可以不用写清楚到底是什么类型的原因。同时有了类型以后编译器才能做很多优化。这样我们就得到了带类型的IR(typed IR),你可以用 code_typed 来查看这部分编译结果
  3. 然后我们用这个IR来产生LLVM IR
  4. LLVM IR会用来产生机器码,你可以用 code_native 宏来获得这个编译结果

为了描述Julia是如何编译的,我从之前的JuliaCon的报告里拿出来一副图。

6ac308e0604b168964ce81c6d34d3578.png

这张图更清楚一些,你可以看到与静态语言不同的是,每次函数调用都会经过编译过程。Julia中的编译(包括JIT编译)是以函数调用为界的。

SSA格式的中间表示

完整的介绍SSA需要很大的篇幅,足够写一本书了[2]。 但是不用担心,我们这里需要用到的部分很简单。你只需要知道下面几个概念就可以了

  • 所有的变量都有且仅有一次赋值(有时候也会说这个是线性的)
  • 大部分变量的值都来自于某个函数的调用(function call)
  • 控制流都变成了分支语句(branching)

如果你已经阅读过我上篇文章,我相信你已经了解计算图这个概念,但是现在我们要重新思考什么是计算图。我们回顾一下上一节用到的图

72b5c420a1183507a748850dd90b5198.png

在进行AD的计算的时候,我们将计算过程表示为一个计算图。每个节点都会使用一个运算符(operator)然后获得一个中间值(一个节点),接下来这个节点的值会和函数一起暂时存起来,在后向传播的时候使用。也就是说每个节点的中间变量都只会被赋值一次,否则就不能唯一对应一个算符。而每个节点有两个函数,一个是代表前向计算的函数,另外一个则是代表后向传播的导数函数。

很自然的,我想你已经发现了,这就是一个SSA的格式。而所谓的自动微分,其实就是我们正常的前向程序的某一个对偶的程序(也就是一个对偶的函数)。而实际上,没有控制流的计算图我们称之为 Wengert Lists [3]。大部分基于运算符重载的自动微分实际上都是实现了这个算法,有时候它也被称为Tape。而SSA格式则更加一般,它包含了控制流。所以我们可以通过对SSA格式来计算自动微分来实现对控制流的自动微分。这也是Zygote第一篇文章所提到的方法[4]

而由于后向传播的函数只是前向传播的函数的伴随(adjoint,实际上一些数学家也认为后向传播可以定义在一个对偶空间上,所以我们不妨就使用这个称谓)。我们不妨直接将这个函数写做一个前向传播函数+一个闭包的格式。

function forward(::typeof(your_function), xs...)# 函数声明output = # 函数输出output, function (Δ)# 这是一个闭包(或者你可以理解成一个能够获取forward函数的local变量的匿名函数)end
end

实现成闭包的好处是实现一些需要使用前向传播的中间值的导数的时候我们可以把这些中间值以闭包函数的状态(state)的方式托管给编译器,而不需要像我在上篇文章里一样,手动将其存在一个Node对象中。我们称这个返回的闭包函数为pullback。

所以假如我们想要获得一个下面这个函数的导数

foo(x) = bar(baz(x))

如果手动做这件事情,我们只需要定义一个 forward 函数

function forward(::typeof(foo), x)x1, back1 = forward(baz, x)x2, back2 = forward(bar, x1)return x2, function (Δ)dx1 = back2(Δ)dx2 = back1(dx1)return dx2end
end

实际上,一般的来说,一段没有控制流的程序的伴随,就是倒着把这段程序中的所有函数调用换成伴随程序的调用,变量换成其对应的伴随变量(adjoint variable)。但是我们如何通过一个函数定义来产生这个forward函数呢?有人可能会说宏,但是宏会要求我们在所有可以进行求导的函数前面都要这样标记,这是我们不希望的,我们希望未来使用这个自动微分的时候我们不需要写任何额外的东西。

而由于SSA格式的IR已经将所有的语法糖,函数都转换成低级表示了,这也就意味着仅仅需要定义一些原始表示,我们就可以用上面这个规则(实际上就是链式法则)组合出非常多的导数,而这些导数的生成都发生在编译时期,所以不会反复占用运行时间,并且这也能帮助我们未来进一步进行一些有针对性的优化。

所以我们想在SSA IR上来做这件事,但是怎么做呢?我们知道宏可以用来修改代码的解析过程(parsing),而Julia里还有另外一个元素用来实现对类型推导和IR的修改,生成函数(generated function)。生成函数可以通过一个宏来声明

@generated function foo(a, b, c)return :(1 + 1)
end

它看起来像是一个普通函数,但是注意它是发生在类型推导期间的

fcd54d3e6f2e5322efdb9e096aa531d1.png

所以你只能知道函数变量 a, b, c的类型信息,我们可以通过类型信息产生两种格式的代码,一种是AST表达式,在Julia里叫Expr,另外一种就是我们的SSA IR,是一种叫CodeInfo的Julia对象。IRTools[5]里提供了操作SSA IR的一些工具,我们将使用这个包来编写产生这个forward函数的代码。

我们可以通过 code_ir 宏来拿到函数的ir对象,这个对象是被IRTools处理过的,它的类型是IR。和 code_typed 宏或者 code_lowered 宏得到的对象不同的是,IR类型实现了一些方便的函数操作,并且IR类型中不会保存变量的名称,所有的变量都用 %数字 来表示

julia> @code_ir foo(1.0)
1: (%1, %2)%3 = (Main.baz)(%2)%4 = (Main.bar)(%3)return %4

注意,你会发现,即便这里 baz和bar这两个函数没定义也不会报错,这是因为Julia本质上还是一个动态语言,所以只要在真正运行的时候才会报错。

这个格式下,每一行代码都绑定了一个变量,等号右边我们称之为声明(statement),左边是变量(variable)。你可以用类似字典的接口来使用这个对象,例如

julia> using IRTools: varjulia> ir[var(3)]
IRTools.Statement(:((Main.baz)(%2)), Any, 1)

它会给你一个声明对象,这个对象里记录了这段声明的表达式,这个宏给你的是没有经过类型推导的IR。所以后面的Any就是这个变量的类型。Any也是Julia中唯一的静态类型。为了简单起见,我们这里不介绍带类型的IR(因为原理上是类似的但是实现细节有一些不同)。最后数字1是指这段声明所在的行号。

前面的1是什么意思呢?在SSA格式中我们用这样的代码块来表示分支,我们不妨写一个ifelse语句看看

julia> function foo(x)if x > 1bar(x)elsebaz(x)endend
foo (generic function with 1 method)julia> @code_ir foo(1.0)
1: (%1, %2)%3 = %2 > 1br 3 unless %3
2:%4 = (Main.bar)(%2)return %4
3:%5 = (Main.baz)(%2)return %5

ifelse在低级表示中是通过branch语句表示的,实际上循环也是类似的。Julia里的循环只是对iterate函数的语法糖而已。所以我们只要能够对br语句进行微分,我们就可以对控制流微分了。

julia> function foo(x)for x in 1:10bar(x)endbaz(x)end
foo (generic function with 1 method)julia> @code_ir foo(1.0)
1: (%1, %2)%3 = 1:10%4 = (Base.iterate)(%3)%5 = %4 === nothing%6 = (Base.not_int)(%5)br 3 unless %6br 2 (%4)
2: (%7)%8 = (Core.getfield)(%7, 1)%9 = (Core.getfield)(%7, 2)%10 = (Main.bar)(%8)%11 = (Base.iterate)(%3, %9)%12 = %11 === nothing%13 = (Base.not_int)(%12)br 3 unless %13br 2 (%11)
3:%14 = (Main.baz)(%2)return %14

那么这个IR是怎么获得的呢?为了获得IR,我们首先需要知道这个通用函数(generic function)被派发了哪个方法(method),在Julia里每个通用函数都有一个方法表(method table)你可以通过这个函数的类型标签来获得具体的方法。例如这个foo函数,每次调用 foo(1.0) 的时候,Julia都会产生下面的标签

Tuple{typeof(foo), Float64}

然后查找这个类型标签对应的方法,然后进行编译。IRTools里提供了一个meta函数来查找这些信息。

julia> T = Tuple{typeof(foo), Float64}
Tuple{typeof(foo),Float64}julia> m = IRTools.meta(T)
Metadata for foo(x) in Main at REPL[2]:1

meta中存储了比较详细的IR信息,函数变量信息等。我们可以通过这个meta信息构建一个可以操作的IR对象。

julia> IRTools.IR(m)
1: (%1, %2)%3 = (Main.baz)(%2)%4 = (Main.bar)(%3)return %4

例如我们可以向这个IR里插入一个Expr,就好像是操作一个列表一样

julia> push!(ir, :(1+1))
%5julia> ir
1: (%1, %2)%3 = (Main.baz)(%2)%4 = (Main.bar)(%3)%5 = 1 + 1return %4

IRTools会自动为你增加变量名。同理我们可以使用insert!操作在第四个变量的前面插入一行

julia> using IRTools: varjulia> insert!(ir, var(4), :(1+1))
%5julia> ir
1: (%1, %2)%3 = (Main.baz)(%2)%5 = 1 + 1%4 = (Main.bar)(%3)return %4

或者我们可以在第四个变量后面插入另外一行

julia> using IRTools: insertafter!julia> insertafter!(ir, var(4), :(2+2))
%6julia> ir
1: (%1, %2)%3 = (Main.baz)(%2)%5 = 1 + 1%4 = (Main.bar)(%3)%6 = 2 + 2return %4

有了这些工具我们基本上可以来做前向传播的IR变换了。我们目标是把每个函数调用换成对forward函数的调用,然后把forward函数返回的pullback收集起来产生一个闭包。但是等等,我好像没有讲在Julia的SSA IR里闭包是什么?我们一会儿再说这件事。我们先看看要如何把函数调用换成forward。

我们先拿出来一个声明看看它长的什么样

julia> dump(ir[var(3)])
IRTools.Statementexpr: Exprhead: Symbol callargs: Array{Any}((2,))1: GlobalRefmod: Module Mainname: Symbol baz2: IRTools.Variableid: Int64 2type: Anyline: Int64 1

那么实际上我们只要检查这个声明的表达式里的head成员是否是call就行了。但是我们不想改变我们原来的IR,IRTools里提供了一个Pipe对象用来辅助进行IR变换,这样我们可以直接往这个对象里进行插入,它会自动将这些插入操作放进一个类似的IR对象里。转换的IR存储在to成员中。Pipe一开始只会构建一个拥有同样变量的代码块

julia> IRTools.Pipe(ir).to
1: (%1, %2)

我们将这个函数命名为register,它的功能类似于我们上次实现的register函数,只是这一次不需要你手动来通过一个特别的类型派发过来了。注意我们这里在演示的时候为了方便在REPL里测试,所有的forward函数都定义在REPL里,也就是名为Main的module里如果你将这些代码写进了自己的module那么你需要修改这里的Main。

function register(ir)pr = Pipe(ir)argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)endendfinish(pr)
end

解释一下上面的代码,首先因为我们不再是调用原本的函数了,新的函数调用变成了多了一个变量的forward

forward(f, args...)

所以我们的IR也需要在最前面增加一个变量。

然后接下来我们遍历所有的SSA变量和其对应的声明(statement),如果声明的标签是一个 :call 的话说明这是一个函数调用,我们把它换成一个 forward调用。但是同时我们要保留原来代码的行号,否则报错信息会不准确。然后我们把这个声明插入到原来的变量的位置上。

但是由于forward会返回两个变量,一个是我们前向传播的值,另外一个是后向传播之后要用的pullback,所以我们需要对返回的这个tuple调用一次getindex。(因为在SSA里没法写 x, y = forward...,想想为什么?) 。但是注意getindex是不能在编译时期调用的,我们要插入一个调用getindex的表达式。这个被实现成一个 xgetindex 函数了

xgetindex(x, i...) = xcall(Base, :getindex, x, i...)

然后我们看看这个函数把IR变成什么了

julia> register(ir)
1: (%3, %1, %2)%4 = (Main.forward)(Main.baz, %2)%5 = (Base.getindex)(%4, 1)%6 = (Main.forward)(Main.bar, %5)%7 = (Base.getindex)(%6, 1)return %7

不错。我们成功把这个函数里的调用修改成了对forward函数的调用。

回过头来,我们考虑闭包的问题。是的在低级表示里,是没有闭包的。但是我们可以用一个类型把它存起来,然后让这个类型变成一个callable。这个类型会同时在类型参数S里存下函数的标签(signature),这样我们在调用pullback的时候可以查到函数的IR。

struct Pullback{S, T}data::T
endPullback{S}(data::T) where {S, T} = Pullback{S, T}(data)

这里 data 里将会存一个Tuple,里面是所有的forward返回的pullback,它们会按照调用顺序存储。

为了将pullback都存起来,我们首先要修改上面的实现,把pullback从forward函数的返回值里取出来,然后把他们都存进一个tuple然后再存到一个Pullback对象里去。因为Pullback需要函数的标签,所以我们要增加一个变量用来输入函数标签。

function register(ir, F)pr = Pipe(ir)pbs = Variable[]argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))push!(pbs, substitute(pr, J))endendpr = finish(pr)v = push!(pr, xtuple(pbs...))pbv = push!(pr, Expr(:call, Pullback{F}, v))return pr
end

这里xtuple类似于xgetindex是一个用来产生调用tuple构造函数表达式的函数

xtuple(xs...) = xcall(Core, :tuple, xs...)

最后让我们把pullback和原本的返回值打包成tuple返回

function register(ir, F)pr = Pipe(ir)pbs = Variable[]argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))push!(pbs, substitute(pr, J))endendpr = finish(pr)v = push!(pr, xtuple(pbs...))pbv = push!(pr, Expr(:call, Pullback{F}, v))ret = pr.blocks[end].branches[end].args[1]ret = push!(pr, xtuple(ret, pbv))pr.blocks[end].branches[end].args[1] = retreturn pr
end

这里return实际上也是一个分支语句,没有分支的函数return是最后一个(也是唯一一个)代码块的最后一个分支的第一个变量,我们修改它为我们要返回的tuple。

好了现在让我们来看看register会把代码变成什么

julia> register(ir, Tuple{typeof(foo), Float64})
1: (%3, %1, %2)%4 = (Main.forward)(Main.baz, %2)%5 = (Base.getindex)(%4, 1)%6 = (Base.getindex)(%4, 2)%7 = (Main.forward)(Main.bar, %5)%8 = (Base.getindex)(%7, 1)%9 = (Base.getindex)(%7, 2)%10 = (Core.tuple)(%9, %6)%11 = (Pullback{Tuple{typeof(foo),Float64},T} where T)(%10)%12 = (Core.tuple)(%8, %11)return %12

接下来让我们来实现forward函数

@generated function forward(f, xs...)# ....
end

注意在generated function里,函数变量的值都是他们的类型。我们首先产生一个函数标签来获取函数的meta

@generated function forward(f, xs...)T = Tuple{f, xs...}m = IRTools.meta(T)m === nothing && return
end

如果meta获取不到返回了nothing说明这个函数标签不存在,也就是说这个method没有定义。我们直接返回。

然后我们用我们的register函数产生这个forward函数的定义

@generated function forward(f, xs...)T = Tuple{f, xs...}m = IRTools.meta(T)m === nothing && returnfrw = register(IR(m), T)
end

但是frw的类型是IR,它并不是真正的Julia中间表示,它只是IRTools创建的一个方便操作的表示,它没有保存函数的变量名称。让我们来手动给他们加上

@generated function forward(f, xs...)T = Tuple{f, xs...}m = IRTools.meta(T)m === nothing && returnfrw = register(IR(m), T)argnames!(m, Symbol("#self#"), :f, :xs)frw = varargs!(m, frw, 2)return IRTools.update!(m, frw)
end

然后我们将forward这个函数的第二个变量xs标记为可变参数(var arg),最后调用update!使用新的meta来产生forward函数的IR。我们不妨来看看这个forward函数的IR是什么样子的

julia> @code_ir forward(foo, 1.0)
1: (%1, %2, %3)%4 = (Base.getfield)(%3, 1)%5 = (Main.forward)(Main.baz, %4)%6 = (Base.getindex)(%5, 1)%7 = (Base.getindex)(%5, 2)%8 = (Main.forward)(Main.bar, %6)%9 = (Base.getindex)(%8, 1)%10 = (Base.getindex)(%8, 2)%11 = (Core.tuple)(%10, %7)%12 = (Main.Pullback{Tuple{typeof(foo),Float64},T} where T)(%11)%13 = (Core.tuple)(%9, %12)return %13

运行一下

julia> forward(foo, 1.0)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
Stacktrace:[1] * at ./float.jl:399 [inlined][2] forward(::typeof(*), ::Float64, ::Float64) at /Users/roger/.julia/dev/YASSAD/src/compiler.jl:0[3] baz at ./REPL[4]:1 [inlined][4] forward(::typeof(baz), ::Float64) at /Users/roger/.julia/dev/YASSAD/src/compiler.jl:0[5] foo at ./REPL[2]:1 [inlined][6] forward(::typeof(foo), ::Float64) at /Users/roger/.julia/dev/YASSAD/src/compiler.jl:0[7] top-level scope at none:0

我们发现它会递归地调用这个forward函数,但是我们并没有为primitive地方法定义导数规则,所以它会在递归到某一步地时候报错。我们可以通过定义一些forward函数作为primitive,比如*

julia> forward(::typeof(*), a::Real, b::Real) = a * b, Δ->(Δ*b, a*Δ)julia> forward(foo, 1.0)
(1.0, YASSAD.Pullback{.....}

我们就会得到一个非常长地Pullback,但是我们还不能调用它,因为我们还没有定义Pullback是如何被调用的。

那么我们成功的将前向求值的过程变成了某种可以追踪的格式,我们接下来来产生Pullback的IR。首先类似于forward我们可以这样定义Pullback

@generated function (::Pullback{S})(delta) where Sm = IRTools.meta(S)m === nothing && returnir = IR(m)_, pbs = register(ir, S)back = adjoint(ir, pbs)argnames!(m, Symbol("#self#"), :delta)return IRTools.update!(m, back)
end

因为后向传播是分开调用的,这个时候我们已经没法直接拿走上面我们处理过的前向传播的IR了,所以我们要再次调用一遍register,但是不用担心,编译只会发生一次,所以它并不会占用我们真正的运行时间。在产生我们前向传播IR的伴随程序的时候我们还需要知道都有哪些变量产生了pullback,所以我们需要一个词典(Dict)来记录这个对应关系。我们把前面register定义修改一下

function register(ir, F)pr = Pipe(ir)pbs = Dict{Variable, Variable}()argument!(pr, at = 1)for (v, st) in prex = st.exprif Meta.isexpr(ex, :call)yJ = insert!(pr, v, stmt(xcall(Main, :forward, ex.args...), line = ir[v].line))pr[v] = xgetindex(yJ, 1)J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line))pbs[v] = substitute(pr, J)endendpr = finish(pr)v = push!(pr, xtuple(values(pbs)...))pbv = push!(pr, Expr(:call, Pullback{F}, v))ret = pr.blocks[end].branches[end].args[1]ret = push!(pr, xtuple(ret, pbv))pr.blocks[end].branches[end].args[1] = retreturn pr, pbs
end

因为伴随程序和原来的IR是相反的顺序,所以我们不能用Pipe了,我们需要创建一个空的IR对象。然后手动给它增加两个变量,一个是Pullback这个对象自己,另外一个是后向传播的梯度。

adj = empty(ir)
self = argument!(adj)
delta = argument!(adj)

首先让我们从我们的对象里取出所有的pullback,这里getfield函数是语法糖 . 的低级函数调用,这句话相当于 self.data

pullbacks = pushfirst!(adj, xcall(:getfield, self, QuoteNode(:data)))

然后我们倒着遍历一遍IR中的所有变量

vars = keys(ir)for k in length(vars):-1:1v = vars[k]ex = ir[v].exprif haskey(pbs, v)pbv = insertafter!(adj, pullbacks, xcall(:getindex, pullbacks, k))g = push!(adj, Expr(:call, pbv, v))endend

如果这个变量在我们的pullback词典里的话,我们就在最前面取出来,命名这个SSA变量为pbv,然后调用它。但是这里有一个问题,如果一个变量有多个梯度(它可能参与了几个函数调用,所以我们会获得几个梯度)我们需要将这几个梯度累加到一起。所以我们需要把这些梯度按照变量记下来

grads = Dict()

然后我们实现一个grad函数,当输入两个变量的时候后面的变量是梯度变量,我们将梯度变量记录在这个变量的列表里,然后返回这个梯度变量

grad(x, x̄) = push!(get!(grads, x, []), x̄)

然后当我们输入一个变量的时候我们返回一个将这些变量累加之后的SSA变量

grad(x) = xaccum(adj, get(grads, x, [])...)

这里xaccum和前面一样是一个用来产生调用accum函数的方法。但是Julia里自带的累加accum函数和我们需要的并不一样(它定义在数组而非变量上),我们来自己定义一个。

xaccum(ir) = nothing
xaccum(ir, x) = x
xaccum(ir, xs...) = push!(ir, xcall(YASSAD, :accum, xs...))
accum() = nothing
accum(x) = x
accum(x, y) =x == nothing ? y :y == nothing ? x :x + yaccum(x, y, zs...) = accum(accum(x, y), zs...)accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

最后我们的pullback将会返回我们每个前馈函数的变量所对应的梯度变量。而如果我们的变量没有梯度。例如forward中的第一个函数变量,在大部分情况下是没有梯度的,但是如果这个函数也是一个闭包函数,或者是一个callable就可以有梯度。所以我们现在约定每个pullback返回的梯度第一个是forward中第一个变量的梯度,它可以是nothing,pullback总是返回和forward函数变量同样数量的梯度。

所以最后我们的adjoint函数就是

function adjoint(ir, pbs)adj = empty(ir)self = argument!(adj)delta = argument!(adj)pullbacks = pushfirst!(adj, xcall(:getfield, self, QuoteNode(:data)))grads = Dict()grad(x, x̄) = push!(get!(grads, x, []), x̄)grad(x) = xaccum(adj, get(grads, x, [])...)grad(last(keys(ir)), delta)vars = keys(ir)for k in length(vars):-1:1v = vars[k]ex = ir[v].exprif haskey(pbs, v)pbv = insertafter!(adj, pullbacks, xcall(:getindex, pullbacks, k))g = push!(adj, Expr(:call, pbv, grad(v)))for (i, x) in enumerate(ex.args)x isa Variable || continuegrad(x, push!(adj, xgetindex(g, i)))endendendgs = [grad(x) for x in arguments(ir)]Δ = push!(adj, xtuple(gs...))return!(adj, Δ)return adj
end

和上篇文章一样,让我们用矩阵乘法+矩阵的迹来试试效果!我们可以直接使用Julia自带的类型了!

using LinearAlgebrafunction forward(::typeof(*), A::Matrix, B::Matrix)A * B, function (Δ::Matrix)Base.@_inline_meta(nothing, Δ * B', A' * Δ)end
endfunction forward(::typeof(tr), A::Matrix)tr(A), function (Δ::Real)Base.@_inline_meta(nothing, Δ * Matrix(I, size(A)))end
endjulia> using LinearAlgebra, BenchmarkToolsjulia> mul_tr(A::Matrix, B::Matrix) = tr(A * B)
mul_tr (generic function with 1 method)julia> A, B = rand(30, 30), rand(30, 30);julia> mul_tr(A, B)
216.7247235502547julia> z, back = forward(mul_tr, A, B);julia> julia> back(1);

而性能和我们手动实现的梯度也是基本一样的 (它其实和我们手动写出来的对于机器来说是一样的)。这是我们上次手动实现的梯度

julia> @benchmark bench_tr_mul_base($(rand(30, 30)), $(rand(30, 30)))
BenchmarkTools.Trial: memory estimate: 28.78 KiBallocs estimate: 5--------------minimum time: 10.696 μs (0.00% GC)median time: 13.204 μs (0.00% GC)mean time: 24.075 μs (43.31% GC)maximum time: 62.964 ms (99.97% GC)--------------samples: 10000evals/sample: 1

这是我们自动产生的梯度

julia> @benchmark tr_mul($A, $B)
BenchmarkTools.Trial: memory estimate: 36.17 KiBallocs estimate: 14--------------minimum time: 12.921 μs (0.00% GC)median time: 15.659 μs (0.00% GC)mean time: 27.304 μs (40.97% GC)maximum time: 60.141 ms (99.94% GC)--------------samples: 10000evals/sample: 1

到此为止我们已经实现了一个非常简单的源对源自动微分,虽然这个自动微分库已经可以使用大部分微分规则。但是在上面的IR变换过程里我们没有处理控制流。而更完整的实现Zygote可以对几乎一切Julia对象进行求导,这包括:自定义类型,控制流,外部函数调用(比如Zygote可以使用PyTorch定义的函数,所以兼容你的旧PyTorch模型),in-place函数。而例子甚至包括:部分我们的量子算法框架Yao[6]定义的量子线路,ODE求解器(NeuralODE很需要这个),光线追踪器(Ray Tracer)。详见Zygote今年NeurIPS的论文[7]。而Zygote本身的实现和我们上面介绍几乎一样简单,我们上面的实现一共用了132行代码,而完整的实现的编译部分也只有495行Julia。而有了编译器自带的反射机制我们连计算图都不需要了,反而能实现控制流的微分。

基于上下文派发

我们回顾上面的实现,其实可以发现我们实际上是根据一个上下文(Context)来进行函数派发。而非根据函数的标签进行派发。我们上面的函数变换实际上是根据我们需要求导这个上下文,重新派发了我们原先的函数调用。Julia社区实际上为这个更加一般的派发机制,利用上面修改IR的方式实现了一个编译器扩展Cassette[8]。Cassette能够根据根据上下文这个信息来修改某段不知道出处的源代码。Cassette的测试中实际上还有一个非常简单的自动微分实现[9]。且由于Julia本身的动态编译机制,我们甚至能够通过这样的方式修改JIT编译过程。这使得我们能够做很多潜在的特定领域的编译和优化。例如:

  • 自动寻找矩阵的稀疏性[10]
  • 自动进行SPMD转换[11]
  • 中间变量优化
  • 制作debugger[12]
  • 让异构计算的界面更加统一[13]

这些特性在其它语言中可能并不会这么容易。在Python中不通过编写完整的编译器则几乎不可能。而受限于Python即便是基于XLA的JAX[14]也无法处理副作用,(有一定动态性的)控制流。而这在Julia中非常简单和直接,并且很自然就可以和整个生态中的其它package进行交互组合,与其费劲全力编写各种各样编译器不如来试试这个和Python很像的新语言吧。(手动狗头)

Julia中文社区将在8月24日在北京微软大厦办一次活动,这次talk的内容质量都很好,Zygote的作者本人也会给一个online talk:https://discourse.juliacn.com/t/topic/2044/3 欢迎来参加。(但是我去不了了哈)

参考

  1. ^Zygote: 21 Century AD https://github.com/FluxML/Zygote.jl
  2. ^Static Single Assignment Book http://ssabook.gforge.inria.fr/latest/book.pdf
  3. ^Wengert Lists https://www.sciencedirect.com/science/article/pii/S0377042700004222
  4. ^Zygote文章:Don't unroll adjoints https://arxiv.org/pdf/1810.07951.pdf
  5. ^IRTools.jl https://github.com/MikeInnes/IRTools.jl
  6. ^Yao.jl: Extensible, Efficient Quantum Algorithm Design for Humans. https://github.com/QuantumBFS/Yao.jl
  7. ^A Differentiable Programming System to Bridge Machine Learning and Scientific Computing https://arxiv.org/abs/1907.07587
  8. ^Cassette: Overdub Your Julia Code https://github.com/jrevels/Cassette.jl
  9. ^Cassette AD https://github.com/jrevels/Cassette.jl/blob/a67c8e98ea975203e46b913807a86de5d3e84130/test/misctaggingtests.jl#L402
  10. ^SparsityDetection https://github.com/JuliaDiffEq/SparsityDetection.jl
  11. ^Hydra https://github.com/FluxML/Hydra.jl
  12. ^MagneticReadHead https://github.com/oxinabox/MagneticReadHead.jl
  13. ^CUDAnative https://github.com/JuliaGPU/CUDAnative.jl/pull/334
  14. ^google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more https://github.com/google/jax



推荐阅读
  • Vue基础一、什么是Vue1.1概念Vue(读音vjuː,类似于view)是一套用于构建用户界面的渐进式JavaScript框架,与其它大型框架不 ... [详细]
  • [转载]从零开始学习OpenGL ES之四 – 光效
    继续我们的iPhoneOpenGLES之旅,我们将讨论光效。目前,我们没有加入任何光效。幸运的是,OpenGL在没有设置光效的情况下仍然可 ... [详细]
  • 一、Hadoop来历Hadoop的思想来源于Google在做搜索引擎的时候出现一个很大的问题就是这么多网页我如何才能以最快的速度来搜索到,由于这个问题Google发明 ... [详细]
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • 本文介绍了数据库的存储结构及其重要性,强调了关系数据库范例中将逻辑存储与物理存储分开的必要性。通过逻辑结构和物理结构的分离,可以实现对物理存储的重新组织和数据库的迁移,而应用程序不会察觉到任何更改。文章还展示了Oracle数据库的逻辑结构和物理结构,并介绍了表空间的概念和作用。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • MACElasticsearch安装步骤及验证方法
    本文介绍了MACElasticsearch的安装步骤,包括下载ZIP文件、解压到安装目录、启动服务,并提供了验证启动是否成功的方法。同时,还介绍了安装elasticsearch-head插件的方法,以便于进行查询操作。 ... [详细]
  • 解决Cydia数据库错误:could not open file /var/lib/dpkg/status 的方法
    本文介绍了解决iOS系统中Cydia数据库错误的方法。通过使用苹果电脑上的Impactor工具和NewTerm软件,以及ifunbox工具和终端命令,可以解决该问题。具体步骤包括下载所需工具、连接手机到电脑、安装NewTerm、下载ifunbox并注册Dropbox账号、下载并解压lib.zip文件、将lib文件夹拖入Books文件夹中,并将lib文件夹拷贝到/var/目录下。以上方法适用于已经越狱且出现Cydia数据库错误的iPhone手机。 ... [详细]
  • HDFS2.x新特性
    一、集群间数据拷贝scp实现两个远程主机之间的文件复制scp-rhello.txtroothadoop103:useratguiguhello.txt推pushscp-rr ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • This article discusses the efficiency of using char str[] and char *str and whether there is any reason to prefer one over the other. It explains the difference between the two and provides an example to illustrate their usage. ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • Introduction(简介)Forbeingapowerfulobject-orientedprogramminglanguage,Cisuseda ... [详细]
  • Answer:Theterm“backslash”isonofthemostincorrectlyusedtermsincomputing.People ... [详细]
  • 学习一门编程语言,除了语法,最重要的是学习解决问题。很多时候单凭自己的能力确实无法做到完美解决,所以无论是搜索引擎、社区、文档还是博客&# ... [详细]
author-avatar
小老特
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有