热门标签 | HotTags
当前位置:  开发笔记 > 后端 > 正文

BERT的初始标准差0.02以及Warmup、LN的作用

前几天在群里大家讨论到了“Transformer如何解决梯度消失”这个问题,答案有提到残差的,也有提到LN(LayerNorm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而

前几天在群里大家讨论到了“Transformer如何解决梯度消失”这个问题,答案有提到残差的,也有提到LN(Layer Norm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而综合的问题,它其实关联到挺多模型细节,比如“BERT为什么要warmup?”、“BERT的初始化标准差为什么是0.02?”、“BERT做MLM预测之前为什么还要多加一层Dense?”,等等。本文就来集中讨论一下这些问题。


梯度消失说的是什么意思? #

在文章《也来谈谈RNN的梯度消失/爆炸问题》中,我们曾讨论过RNN的梯度消失问题。事实上,一般模型的梯度消失现象也是类似,它指的是(主要是在模型的初始阶段)越靠近输入的层梯度越小,趋于零甚至等于零,而我们主要用的是基于梯度的优化器,所以梯度消失意味着我们没有很好的信号去调整优化前面的层。

换句话说,前面的层也许几乎没有得到更新,一直保持随机初始化的状态;只有比较靠近输出的层才更新得比较好,但这些层的输入是前面没有更新好的层的输出,所以输入质量可能会很糟糕(因为经过了一个近乎随机的变换),因此哪怕后面的层更新好了,总体效果也不好。最终,我们会观察到很反直觉的现象:模型越深,效果越差,哪怕训练集都如此。

解决梯度消失的一个标准方法就是残差链接,正式提出于ResNet中。残差的思想非常简单直接:你不是担心输入的梯度会消失吗?那我直接给它补上一个梯度为常数的项不就行了?最简单地,将模型变成

这样一来,由于多了一条“直通”路x">xx,就算F(x)">F(x)F(x)中的x">xx梯度消失了,x">xx的梯度基本上也能得以保留,从而使得深层模型得到有效的训练。

 


LN真的能缓解梯度消失? #

然而,在BERT和最初的Transformer里边,使用的是Post Norm设计,它把Norm操作加在了残差之后:

其实具体的Norm方法不大重要,不管是Batch Norm还是Layer Norm,结论都类似。在文章《浅谈Transformer的初始化、参数化与标准化》中,我们已经分析过这种Norm结构,这里再来重复一下。

 

在初始化阶段,由于所有参数都是随机初始化的,所以我们可以认为x">xxF(x)">F(x)F(x)是两个相互独立的随机向量,如果假设它们各自的方差是1,那么x+F(x)">x+F(x)x+F(x)的方差就是2,而Norm">NormNorm操作负责将方差重新变为1,那么在初始化阶段,Norm">NormNorm操作就相当于“除以√2">2”:

递归下去就是

我们知道,残差有利于解决梯度消失,但是在Post Norm中,残差这条通道被严重削弱了,越靠近输入,削弱得越严重,残差“名存实亡”。所以说,在Post Norm的BERT模型中,LN不仅不能缓解梯度消失,它还是梯度消失的“元凶”之一。

 


那我们为什么还要加LN? #

那么,问题自然就来了:既然LN还加剧了梯度消失,那直接去掉它不好吗?

是可以去掉,但是前面说了,x+F(x)">x+F(x)x+F(x)的方差就是2了,残差越多方差就越大了,所以还是要加一个Norm操作,我们可以把它加到每个模块的输入,即变为x+F(Norm(x))">x+F(Norm(x))x+F(Norm(x)),最后的总输出再加个Norm">NormNorm就行,这就是Pre Norm结构,这时候每个残差分支是平权的,而不是像Post Norm那样有指数衰减趋势。当然,也有完全不加Norm的,但需要对F(x)">F(x)F(x)进行特殊的初始化,让它初始输出更接近于0,比如ReZero、Skip Init、Fixup等,这些在《浅谈Transformer的初始化、参数化与标准化》也都已经介绍过了。

但是,抛开这些改进不说,Post Norm就没有可取之处吗?难道Transformer和BERT开始就带了一个完全失败的设计?

显然不大可能。虽然Post Norm会带来一定的梯度消失问题,但其实它也有其他方面的好处。最明显的是,它稳定了前向传播的数值,并且保持了每个模块的一致性。比如BERT base,我们可以在最后一层接一个Dense来分类,也可以取第6层接一个Dense来分类;但如果你是Pre Norm的话,取出中间层之后,你需要自己接一个LN然后再接Dense,否则越靠后的层方差越大,不利于优化。

其次,梯度消失也不全是“坏处”,其实对于Finetune阶段来说,它反而是好处。在Finetune的时候,我们通常希望优先调整靠近输出层的参数,不要过度调整靠近输入层的参数,以免严重破坏预训练效果。而梯度消失意味着越靠近输入层,其结果对最终输出的影响越弱,这正好是Finetune时所希望的。所以,预训练好的Post Norm模型,往往比Pre Norm模型有更好的Finetune效果,这我们在《RealFormer:把残差转移到Attention矩阵上面去》也提到过。


我们真的担心梯度消失吗? #

其实,最关键的原因是,在当前的各种自适应优化技术下,我们已经不大担心梯度消失问题了。

这是因为,当前NLP中主流的优化器是Adam及其变种。对于Adam来说,由于包含了动量和二阶矩校正,所以近似来看,它的更新量大致上为

可以看到,分子分母是都是同量纲的,因此分式结果其实就是O(1)">O(1)O(1)的量级,而更新量就是O(η)">O(η)O(η)量级。也就是说,理论上只要梯度的绝对值大于随机误差,那么对应的参数都会有常数量级的更新量;这跟SGD不一样,SGD的更新量是正比于梯度的,只要梯度小,更新量也会很小,如果梯度过小,那么参数几乎会没被更新。

 

所以,Post Norm的残差虽然被严重削弱,但是在base、large级别的模型中,它还不至于削弱到小于随机误差的地步,因此配合Adam等优化器,它还是可以得到有效更新的,也就有可能成功训练了。当然,只是有可能,事实上越深的Post Norm模型确实越难训练,比如要仔细调节学习率和Warmup等。


Warmup是怎样起作用的? #

大家可能已经听说过,Warmup是Transformer训练的关键步骤,没有它可能不收敛,或者收敛到比较糟糕的位置。为什么会这样呢?不是说有了Adam就不怕梯度消失了吗?

要注意的是,Adam解决的是梯度消失带来的参数更新量过小问题,也就是说,不管梯度消失与否,更新量都不会过小。但对于Post Norm结构的模型来说,梯度消失依然存在,只不过它的意义变了。根据泰勒展开式:

也就是说增量f(x+Δx)−f(x)">f(x+Δx)f(x)f(x+Δx)−f(x)是正比于梯度的,换句话说,梯度衡量了输出对输入的依赖程度。如果梯度消失,那么意味着模型的输出对输入的依赖变弱了。

 

Warmup是在训练开始阶段,将学习率从0缓增到指定大小,而不是一开始从指定大小训练。如果不进行Wamrup,那么模型一开始就快速地学习,由于梯度消失,模型对越靠后的层越敏感,也就是越靠后的层学习得越快,然后后面的层是以前面的层的输出为输入的,前面的层根本就没学好,所以后面的层虽然学得快,但却是建立在糟糕的输入基础上的。

很快地,后面的层以糟糕的输入为基础到达了一个糟糕的局部最优点,此时它的学习开始放缓(因为已经到达了它认为的最优点附近),同时反向传播给前面层的梯度信号进一步变弱,这就导致了前面的层的梯度变得不准。但我们说过,Adam的更新量是常数量级的,梯度不准,但更新量依然是数量级,意味着可能就是一个常数量级的随机噪声了,于是学习方向开始不合理,前面的输出开始崩盘,导致后面的层也一并崩盘。

所以,如果Post Norm结构的模型不进行Wamrup,我们能观察到的现象往往是:loss快速收敛到一个常数附近,然后再训练一段时间,loss开始发散,直至NAN。如果进行Wamrup,那么留给模型足够多的时间进行“预热”,在这个过程中,主要是抑制了后面的层的学习速度,并且给了前面的层更多的优化时间,以促进每个层的同步优化。

这里的讨论前提是梯度消失,如果是Pre Norm之类的结果,没有明显的梯度消失现象,那么不加Warmup往往也可以成功训练。


初始标准差为什么是0.02? #

喜欢扣细节的同学会留意到,BERT默认的初始化方法是标准差为0.02的截断正态分布,在《浅谈Transformer的初始化、参数化与标准化》我们也提过,由于是截断正态分布,所以实际标准差会更小,大约是0.02/1.1368472≈0.0176">0.02/1.13684720.01760.02/1.1368472≈0.0176。这个标准差是大还是小呢?对于Xavier初始化来说,一个n×n">n×nn×n的矩阵应该用1/n">1/n1/n的方差初始化,而BERT base的n">nn为768,算出来的标准差是1/768≈0.0361">1/√7680.0361。这就意味着,这个初始化标准差是明显偏小的,大约只有常见初始化标准差的一半。

为什么BERT要用偏小的标准差初始化呢?事实上,这还是跟Post Norm设计有关,偏小的标准差会导致函数的输出整体偏小,从而使得Post Norm设计在初始化阶段更接近于恒等函数,从而更利于优化。具体来说,按照前面的假设,如果x">xx的方差是1,F(x)">F(x)F(x)的方差是σ2">σ2σ2,那么初始化阶段,Norm">NormNorm操作就相当于除以√(1+σ2">1+σ2)。如果σ">σσ比较小,那么残差中的“直路”权重就越接近于1,那么模型初始阶段就越接近一个恒等函数,就越不容易梯度消失。

正所谓“我们不怕梯度消失,但我们也不希望梯度消失”,简单地将初始化标注差设小一点,就可以使得σ">σσ变小一点,从而在保持Post Norm的同时缓解一下梯度消失,何乐而不为?那能不能设置得更小甚至全零?一般来说初始化过小会丧失多样性,缩小了模型的试错空间,也会带来负面效果。综合来看,缩小到标准的1/2,是一个比较靠谱的选择了。

当然,也确实有人喜欢挑战极限的,最近笔者也看到了一篇文章,试图让整个模型用几乎全零的初始化,还训练出了不错的效果,大家有兴趣可以读读,文章为《ZerO Initialization: Initializing Residual Networks with only Zeros and Ones》。


为什么MLM要多加Dense? #

最后,是关于BERT的MLM模型的一个细节,就是BERT在做MLM的概率预测之前,还要多接一个Dense层和LN层,这是为什么呢?不接不行吗?

之前看到过的答案大致上是觉得,越靠近输出层的,越是依赖任务的(Task-Specified),我们多接一个Dense层,希望这个Dense层是MLM-Specified的,然后下游任务微调的时候就不是MLM-Specified的,所以把它去掉。这个解释看上去有点合理,但总感觉有点玄学,毕竟Task-Specified这种东西不大好定量分析。

这里笔者给出另外一个更具体的解释,事实上它还是跟BERT用了0.02的标准差初始化直接相关。刚才我们说了,这个初始化是偏小的,如果我们不额外加Dense就乘上Embedding预测概率分布,那么得到的分布就过于均匀了(Softmax之前,每个logit都接近于0),于是模型就想着要把数值放大。现在模型有两个选择:第一,放大Embedding层的数值,但是Embedding层的更新是稀疏的,一个个放大太麻烦;第二,就是放大输入,我们知道BERT编码器最后一层是LN,LN最后有个初始化为1的gamma参数,直接将那个参数放大就好。

模型优化使用的是梯度下降,我们知道它会选择最快的路径,显然是第二个选择更快,所以模型会优先走第二条路。这就导致了一个现象:最后一个LN层的gamma值会偏大。如果预测MLM概率分布之前不加一个Dense+LN,那么BERT编码器的最后一层的LN的gamma值会偏大,导致最后一层的方差会比其他层的明显大,显然不够优雅;而多加了一个Dense+LN后,偏大的gamma就转移到了新增的LN上去了,而编码器的每一层则保持了一致性。

事实上,读者可以自己去观察一下BERT每个LN层的gamma值,就会发现确实是最后一个LN层的gamma值是会明显偏大的,这就验证了我们的猜测~


希望大家多多海涵批评斧正 #

本文试图回答了Transformer、BERT的模型优化相关的几个问题,有一些是笔者在自己的预训练工作中发现的结果,有一些则是结合自己的经验所做的直观想象。不管怎样,算是分享一个参考答案吧,如果有不当的地方,请大家海涵,也请各位批评斧正~

 

来自

https://kexue.fm/archives/8747

 


欢迎转载,转载请保留页面地址。帮助到你的请点个推荐。



推荐阅读
  • 本文详细介绍了 Spark 中的弹性分布式数据集(RDD)及其常见的操作方法,包括 union、intersection、cartesian、subtract、join、cogroup 等转换操作,以及 count、collect、reduce、take、foreach、first、saveAsTextFile 等行动操作。 ... [详细]
  • 本文将详细介绍如何在佳明手表上选择和设置原有的或自定义的表盘,帮助用户轻松完成个性化设置。 ... [详细]
  • 本文详细介绍了如何使用 Python 进行主成分分析(PCA),包括数据导入、预处理、模型训练和结果可视化等步骤。通过具体的代码示例,帮助读者理解和应用 PCA 技术。 ... [详细]
  • 本文介绍如何使用OpenCV和线性支持向量机(SVM)模型来开发一个简单的人脸识别系统,特别关注在只有一个用户数据集时的处理方法。 ... [详细]
  • 使用jqTransform插件美化表单
    jqTransform 是由 DFC Engineering 开发的一款 jQuery 插件,专用于美化表单元素,操作简便,能够美化包括输入框、单选按钮、多行文本域、下拉选择框和复选框在内的所有表单元素。 ... [详细]
  • Vision Transformer (ViT) 和 DETR 深度解析
    本文详细介绍了 Vision Transformer (ViT) 和 DETR 的工作原理,并提供了相关的代码实现和参考资料。通过观看教学视频和阅读博客,对 ViT 的全流程进行了详细的笔记整理,包括代码详解和关键概念的解释。 ... [详细]
  • 单片微机原理P3:80C51外部拓展系统
      外部拓展其实是个相对来说很好玩的章节,可以真正开始用单片机写程序了,比较重要的是外部存储器拓展,81C55拓展,矩阵键盘,动态显示,DAC和ADC。0.IO接口电路概念与存 ... [详细]
  • OpenAI首席执行官Sam Altman展望:人工智能的未来发展方向与挑战
    OpenAI首席执行官Sam Altman展望:人工智能的未来发展方向与挑战 ... [详细]
  • 通过手机获取的GPS坐标在手机地图上存在约100-200米的偏差,但在Google Maps中搜索同一坐标时,定位非常精确。这种偏差可能出于安全或隐私考虑而被有意引入。此外,不同设备和环境下的GPS信号强度和精度也会影响最终的定位结果。 ... [详细]
  • 脑机接口技术在物联网行业中的应用与前景分析
    近期,国际研究人员开发了一种轻便的脑电图(EEG)采集与信号处理系统,并在物联网领域进行了初步应用研究。该系统配备了8个可扩展的采集电极和1个参考电极,具备高灵敏度的放大功能,能够有效捕捉和处理脑电信号。通过与物联网技术的结合,该系统有望在智能家居、健康监测和人机交互等领域发挥重要作用,展现出广阔的应用前景。 ... [详细]
  • 本指南从零开始介绍Scala编程语言的基础知识,重点讲解了Scala解释器REPL(读取-求值-打印-循环)的使用方法。REPL是Scala开发中的重要工具,能够帮助初学者快速理解和实践Scala的基本语法和特性。通过详细的示例和练习,读者将能够熟练掌握Scala的基础概念和编程技巧。 ... [详细]
  • 针对图像分类任务的训练方案进行了优化设计。通过引入PyTorch等深度学习框架,利用其丰富的工具包和模块,如 `torch.nn` 和 `torch.nn.functional`,提升了模型的训练效率和分类准确性。优化方案包括数据预处理、模型架构选择和损失函数的设计等方面,旨在提高图像分类任务的整体性能。 ... [详细]
  • 每日前端实战:148# 视频教程展示纯 CSS 实现按钮两侧滑入装饰元素的悬停效果
    通过点击页面右侧的“预览”按钮,您可以直接在当前页面查看效果,或点击链接进入全屏预览模式。该视频教程展示了如何使用纯 CSS 实现按钮两侧滑入装饰元素的悬停效果。视频内容具有互动性,观众可以实时调整代码并观察变化。访问以下链接体验完整效果:https://codepen.io/comehope/pen/yRyOZr。 ... [详细]
  • 本文介绍了如何在iOS平台上使用GLSL着色器将YV12格式的视频帧数据转换为RGB格式,并展示了转换后的图像效果。通过详细的技术实现步骤和代码示例,读者可以轻松掌握这一过程,适用于需要进行视频处理的应用开发。 ... [详细]
  • 本文深入探讨了HTTP头部中的Expires与Cache-Control字段及其缓存机制。Cache-Control字段主要用于控制HTTP缓存行为,其在HTTP/1.1中得到了广泛应用,而HTTP/1.0中主要使用Pragma:no-cache来实现类似功能。Expires字段则定义了资源的过期时间,帮助浏览器决定是否从缓存中读取资源。文章详细解析了这两个字段的具体用法、相互关系以及在不同场景下的应用效果,为开发者提供了全面的缓存管理指南。 ... [详细]
author-avatar
手机用户2502862657
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有