热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

ICLR2020满分论文|为什么梯度裁剪能加速模型训练?

一只小狐狸带你解锁炼丹术&NLP秘籍作者:苏剑林(来自追一科技,人称“苏神”)前言需要许多时间步计算的循环神经网络ÿ

一只小狐狸带你解锁 炼丹术&NLP 秘籍

作者:苏剑林(来自追一科技,人称“苏神”)

前言

需要许多时间步计算的循环神经网络,如LSTM、GRU,往往存在梯度爆炸的问题。其目标函数可能存在悬崖一样斜率较大的区域,这是由于时间步上几个较大的权重相乘导致的。当参数接近这样的悬崖区域时,如果更新梯度不足够小,很有可能就会直接跳过这样的悬崖结构,然后被弹射到非常远的地方。梯度裁剪(gradient clipping),是这类问题的常用解决办法。它的核心思想就是根据目标函数的光滑程度对梯度进行缩放[1]

本文介绍来自MIT的一篇ICLR2020满分论文《Why gradient clipping accelerates training: A theoretical justification for adaptivity》。顾名思义,这篇论文就是分析为什么梯度裁剪能加速深度学习的训练过程。原文很长,公式很多,还有不少研究复杂性的概念,说实话对笔者来说里边的大部分内容也是懵的,不过大概能捕捉到它的核心思想:引入了比常用的L约束更宽松的约束条件,从新的条件出发论证了梯度裁剪的必要性。本文就是来简单描述一下这个过程,供读者参考。

论文链接:https://arxiv.org/pdf/1905.11881.pdf

Arxiv访问慢的小伙伴也可以在订阅号后台回复关键词【0615】下载论文PDF。

梯度裁剪

假设需要最小化的函数为

,

就是优化参数,那么梯度下降的更新公式就是(滑动查看完整公式)

其中

就是学习率。而所谓梯度裁剪(gradient clipping),就是根据梯度的模长来对更新量做一个缩放,比如

或者

其中 

,是一个常数。这两种方式都被视为梯度裁剪,总的来说就是控制更新量的模长不超过一个常数。其实从下面的不等式就可以看到其实两者基本是等价的:

L约束

有不少优化器相关的理论结果,在其证明中都假设了待优化函数

的梯度满足如下的L约束:

由于 

是梯度的波动程度,实际上衡量的就是 

的光滑程度,所以上述约束也称为“L光滑性条件(L-smooth)”[2]。值得提醒的是,不同的场景可能会需要不同的L约束,比如有时候我们要假设模型输出关于输入满足L约束,有时候我们要假设模型输出关于参数满足L约束,而上面假设的是模型 loss 的梯度关于参数满足L约束。如果条件 (5) 成立,那么很多优化问题都将大大简化。因为我们可以证明[3]:

对于梯度下降来说,

,代入上式得到

因此,为了保证每一步优化都使得 

下降,一个充分条件是 

,即 

,而 

的最小值在 

时取到,所以只需要让学习率为 

,那么每步迭代都可以使得 

下降,并且下降速度最快。

放松约束

条件 (5) 还可以带来很多漂亮的结果,然而问题是在很多实际优化问题中条件 (5) 并不成立,比如四次函数 

。这就导致了理论与实际的差距。而本文要介绍的论文,则引入了一个新的更宽松的约束:

也就是将常数 

换成动态的 

,原文称之为“(L0, L1)-smooth”,这里也称为“(L0, L1)约束”。显然这个条件就宽松多了,比如可以检验 

是满足这个条件的,因此基于此条件所推导出的理论结果适用范围会更广。

在新的约束之下,不等式 (6) 依旧是成立的,只不过

换成对应的动态项:

代入

得到

所以很明显了,现在要保证每一步下降,那么就要求

以及最优学习率是

这就导出了梯度裁剪 (3)。而保证了每一步都下降,那么就意味着在优化过程中每一步都没有做无用功,所以也就加速了训练过程。

作者们是怎么提出这个条件 (8) 的呢?论文中说是做实验观察出来的:观察到损失函数的光滑程度与梯度模长呈“线性相关”关系.png,如下图所示。但笔者感觉吧,至少应该还有些从结果反推的成分在里边,不然谁那么无聊会去观察这两者的关系呢?

文章小结

本文简要介绍了ICLR2020的一篇分析梯度裁剪的满分论文,主要思路是引入了更宽松普适的假设条件,在新的条件下能体现出了梯度裁剪的必要性,并且由于放松了传统的约束,因此理论结果的适用范围更广,这也就表明,梯度裁剪确实是很多场景下都适用的技巧之一。

参考文献

[1]

参考文献 lan Goodfellow et. al, "Deep Learning", MIT press, 2016

[2]

关于L约束可以作者其他博客: 《深度学习中的Lipschitz约束:泛化与生成模型》、《BN究竟起了什么作用?一个闭门造车的分析》。

[3]

证明过程可参考https://kexue.fm/archives/6992。

  • 万能的BERT连文本纠错也不放过

  • 面试必备!卖萌屋算法工程师思维导图—统计机器学习篇

  • 告别自注意力,谷歌为Transformer打造新内核Synthesizer

  • NLP中的少样本困境问题探究

  • ACL20 | 让笨重的BERT问答匹配模型变快!

  • 7款优秀Vim插件帮你打造完美IDE

  • 卖萌屋原创专辑首发,算法镇魂三部曲!

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦



推荐阅读
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 干货 | 携程AI推理性能的自动化优化实践
    作者简介携程度假AI研发团队致力于为携程旅游事业部提供丰富的AI技术产品,其中性能优化组为AI模型提供全方位的优化方案,提升推理性能降低成本࿰ ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 建立分类感知器二元模型对样本数据进行分类
    本文介绍了建立分类感知器二元模型对样本数据进行分类的方法。通过建立线性模型,使用最小二乘、Logistic回归等方法进行建模,考虑到可能性的大小等因素。通过极大似然估计求得分类器的参数,使用牛顿-拉菲森迭代方法求解方程组。同时介绍了梯度上升算法和牛顿迭代的收敛速度比较。最后给出了公式法和logistic regression的实现示例。 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
  • 2017亚马逊人工智能奖公布:他们的AI有什么不同?
    事实上,在我们周围,“人工智能”让一切都变得更“智能”极具讽刺意味。随着人类与机器智能之间的界限变得模糊,我们的世界正在变成一个机器 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 3年半巨亏242亿!商汤高估了深度学习,下错了棋?
    转自:新智元三年半研发开支近70亿,累计亏损242亿。AI这门生意好像越来越不好做了。近日,商汤科技已向港交所递交IPO申请。招股书显示& ... [详细]
  • Two Sigma人均22万英镑~
    近期原创文章: ... [详细]
  • GPT-3发布,动动手指就能自动生成代码的神器来了!
    近日,OpenAI发布了最新的NLP模型GPT-3,该模型在GitHub趋势榜上名列前茅。GPT-3使用的数据集容量达到45TB,参数个数高达1750亿,训练好的模型需要700G的硬盘空间来存储。一位开发者根据GPT-3模型上线了一个名为debuid的网站,用户只需用英语描述需求,前端代码就能自动生成。这个神奇的功能让许多程序员感到惊讶。去年,OpenAI在与世界冠军OG战队的表演赛中展示了他们的强化学习模型,在限定条件下以2:0完胜人类冠军。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • OCR:用字符识别方法将形状翻译成计算机文字的过程Matlab:商业数学软件;CUDA:CUDA™是一种由NVIDIA推 ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
author-avatar
被爱的小花花_
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有