热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【模型压缩】知识蒸馏经典解读

作者|小小字节跳动整理|NewBeeNLP写在前面知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,由于其简单,有效

作者 | 小小@字节跳动

整理 | NewBeeNLP

写在前面

知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,由于其简单,有效,在工业界被广泛应用。这一技术的理论来自于2015年Hinton发表的一篇神作:Distilling the Knowledge in a Neural Network[1]

Knowledge Distillation,简称KD,顾名思义,就是将已经训练好的模型包含的知识(”Knowledge”),蒸馏(“Distill”)提取到另一个模型里面去。今天,我们就来简单读一下这篇论文,力求用简单的语言描述论文作者的主要思想。在本文中,我们将从背景和动机讲起,然后着重介绍“知识蒸馏”的方法,最后我会讨论“温度“这个名词:

  • 「温度」: 我们都知道“蒸馏”需要在高温下进行,那么这个“蒸馏”的温度代表了什么,又是如何选取合适的温度?

介绍

论文提出的背景:

虽然在一般情况下,我们不会去区分训练和部署使用的模型,但是训练和部署之间存在着一定的不一致性:

  • 在训练过程中,我们需要使用复杂的模型,大量的计算资源,以便从非常大、高度冗余的数据集中提取出信息。在实验中,效果最好的模型往往规模很大,甚至由多个模型集成得到。而大模型不方便部署到服务中去,常见的瓶颈如下:

    • 推断速度慢

    • 对部署资源要求高(内存,显存等)

  • 在部署时,我们对延迟以及计算资源都有着严格的限制。

因此,模型压缩(在保证性能的前提下减少模型的参数量)成为了一个重要的问题。而「模型蒸馏」属于模型压缩的一种方法。

插句题外话,我们可以从模型参数量和训练数据量之间的相对关系来理解underfitting和overfitting。AI领域的从业者可能对此已经习以为常,但是为了力求让小白也能读懂本文,还是引用我同事的解释(我印象很深)形象地说明一下:

模型就像一个容器,训练数据中蕴含的知识就像是要装进容器里的水。

  • 当数据知识量(水量)超过模型所能建模的范围时(容器的容积),加再多的数据也不能提升效果(水再多也装不进容器),因为模型的表达空间有限(容器容积有限),就会造成underfitting;

  • 而当模型的参数量大于已有知识所需要的表达空间时(容积大于水量,水装不满容器),就会造成overfitting,即模型的bias会增大(想象一下摇晃半满的容器,里面水的形状是不稳定的)。

“思想歧路”

上面容器和水的比喻非常经典和贴切,但是会引起一个误解: 人们在直觉上会觉得,要保留相近的知识量,必须保留相近规模的模型。也就是说,一个模型的参数量基本决定了其所能捕获到的数据内蕴含的“知识”的量。

这样的想法是基本正确的,但是需要注意的是:

  1. 模型的参数量和其所能捕获的“知识“量之间并非稳定的线性关系(下图中的1),而是接近边际收益逐渐减少的一种增长曲线(下图中的2和3)

  2. 完全相同的模型架构和模型参数量,使用完全相同的训练数据,能捕获的“知识”量并不一定完全相同,另一个关键因素是训练的方法。合适的训练方法可以使得在模型参数总量比较小时,尽可能地获取到更多的“知识”(下图中的3与2曲线的对比).

知识蒸馏的理论依据

Teacher Model和Student Model

知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:

  1. 原始模型训练: 训练”Teacher模型”, 简称为Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对”Teacher模型”不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。

  2. 精简模型训练: 训练”Student模型”, 简称为Net-S,它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。

在本论文中,作者将问题限定在「分类问题」下,或者其他本质上属于分类问题的问题,该类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。

知识蒸馏的关键点

如果回归机器学习最最基础的理论,我们可以很清楚地意识到一点(而这一点往往在我们深入研究机器学习之后被忽略): 机器学习「最根本的目的」在于训练出在某个问题上泛化能力强的模型。

  • 泛化能力强: 在某问题的所有数据上都能很好地反应输入和输出之间的关系,无论是训练数据,还是测试数据,还是任何属于该问题的未知数据。

而现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。

而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。

一个很直白且高效的迁移泛化能力的方法就是使用softmax层输出的类别的概率来作为“soft target”。

  1. 传统training过程(hard targets): 对ground truth求极大似然

  2. KD的training过程(soft targets): 用large model的class probabilities作为soft targets

上图: Hard Target 下图: Soft Target

「为什么?」
softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。

举个例子来说明一下: 在手写体数字识别任务MNIST中,输出类别有10个。

假设某个输入的“2”更加形似”3”,softmax的输出值中”3”对应的概率为0.1,而其他负标签对应的值都很小,而另一个”2”更加形似”7”,”7”对应的概率为0.1。这两个”2”对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。

两个”2“的hard target相同而soft target不同

这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。

softmax函数

先回顾一下原始的softmax函数:

但要是直接使用softmax层的输出值作为soft target, 这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此”温度”这个变量就派上了用场。

下面的公式时加了温度这个变量之后的softmax函数:

  • 这里的T就是「温度」

  • 原来的softmax函数是T = 1的特例。T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

知识蒸馏的具体方法

通用的知识蒸馏方法

  • 第一步是训练Net-T;第二步是在高温T下,蒸馏Net-T的知识到Net-S

训练Net-T的过程很简单,下面详细讲讲第二步:高温蒸馏的过程。高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到。示意图如上。

  • : Net-T的logits

  • : Net-S的logits

  • : Net-T的在温度=T下的softmax输出在第i类上的值

  • : Net-S的在温度=T下的softmax输出在第i类上的值

  • : 在第i类上的ground truth值,

    , 正标签取1,负标签取0.

  • : 总标签数量

  • Net-T 和 Net-S同时输入 transfer set (这里可以直接复用训练Net-T用到的training set), 用Net-T产生的softmax distribution (with high temperature) 来作为soft target,Net-S在相同温度T下的softmax输出和soft target的cross entropy就是「Loss函数的第一部分」

    .

, 其中,

  • Net-S在温度=1下的softmax输出和ground truth的cross entropy就是「Loss函数的第二部分」

    .

, 其中.

  • 第二部分Loss

    的必要性其实很好理解: Net-T也有一定的错误率,使用ground truth可以有效降低错误被传播给Net-S的可能。打个比方,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。

「讨论」

  • 实验发现第二部分所占比重比较小的时候,能产生最好的结果,这是一个经验的结论。一个可能的原因是,由于soft target产生的gradient与hard target产生的gradient之间有与 T 相关的比值。



另外,。因此在同时使用soft target和hard target的时候,需要在soft target之前乘上

的系数,这样才能保证soft target和hard target贡献的梯度量基本一致。

  • 「注意」: 在Net-S训练完毕后,做inference时其softmax的温度T要恢复到1.

一种特殊情形: 直接match logits(不经过softmax)

直接match logits指的是,直接使用softmax层的输入logits(而不是输出)作为soft targets,需要最小化的目标函数是Net-T和Net-S的logits之间的平方差。

「直接match logits的做法是」

「的情况下的特殊情形。」

由单个case贡献的loss,推算出对应在Net-S每个logit

上的gradient:

时,我们使用

来近似

,于是得到

如果再加上logits是零均值的假设

那么上面的公式可以简化成

也就是等价于minimise下面的损失函数

关于”温度”的讨论

【问题】 我们都知道“蒸馏”需要在高温下进行,那么这个“蒸馏”的温度代表了什么,又是如何选取合适的温度?

随着温度T的增大,概率分布的熵逐渐增大

温度的特点

在回答这个问题之前,先讨论一下「温度T的特点」

  1. 原始的softmax函数是

    时的特例,

    时,概率分布比原始更“陡峭”,

    时,概率分布比原始更“平缓”。

  2. 温度越高,softmax上各个值的分布就越平均(思考极端情况: (i)

    , 此时softmax的值是平均分布的;(ii)

    ,此时softmax的值就相当于

    ,即最大的概率处的值趋近于1,而其他值趋近于0)

  3. 不管温度T怎么取值,Soft target都有忽略小的

    携带的信息的倾向

温度代表了什么,如何选取合适的温度?

温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。

实际上,负标签中包含一定的信息,尤其是那些值显著「高于」平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:

  1. 从有部分信息量的负标签中学习 –> 温度要高一些

  2. 防止受负标签中噪声的影响 –>温度要低一些

总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)

本文参考资料

[1]

Distilling the Knowledge in a Neural Network: https://arxiv.org/abs/1503.02531

[2]

深度压缩之蒸馏模型 - 风雨兼程的文章 - 知乎: https://zhuanlan.zhihu.com/p/24337627>

[3]

知识蒸馏Knowledge Distillation - 船长的文章 - 知乎: https://zhuanlan.zhihu.com/p/83456418

[4]

knowledge-distillation-simplified: https://towardsdatascience.com/knowledge-distillation-simplified-dd4973dbc764

[5]

knowledge_distillation: https://nervanasystems.github.io/distiller/knowledge_distillation.html

END -

添加个人微信,备注:昵称-学校(公司)-方向,即可获得

1. 快速学习深度学习五件套资料

2. 进入高手如云DL&NLP交流群

记得备注呦



推荐阅读
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文讨论了如何优化解决hdu 1003 java题目的动态规划方法,通过分析加法规则和最大和的性质,提出了一种优化的思路。具体方法是,当从1加到n为负时,即sum(1,n)sum(n,s),可以继续加法计算。同时,还考虑了两种特殊情况:都是负数的情况和有0的情况。最后,通过使用Scanner类来获取输入数据。 ... [详细]
  • 本文介绍了九度OnlineJudge中的1002题目“Grading”的解决方法。该题目要求设计一个公平的评分过程,将每个考题分配给3个独立的专家,如果他们的评分不一致,则需要请一位裁判做出最终决定。文章详细描述了评分规则,并给出了解决该问题的程序。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • 开发笔记:实验7的文件读写操作
    本文介绍了使用C++的ofstream和ifstream类进行文件读写操作的方法,包括创建文件、写入文件和读取文件的过程。同时还介绍了如何判断文件是否成功打开和关闭文件的方法。通过本文的学习,读者可以了解如何在C++中进行文件读写操作。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 本文介绍了使用哈夫曼树实现文件压缩和解压的方法。首先对数据结构课程设计中的代码进行了分析,包括使用时间调用、常量定义和统计文件中各个字符时相关的结构体。然后讨论了哈夫曼树的实现原理和算法。最后介绍了文件压缩和解压的具体步骤,包括字符统计、构建哈夫曼树、生成编码表、编码和解码过程。通过实例演示了文件压缩和解压的效果。本文的内容对于理解哈夫曼树的实现原理和应用具有一定的参考价值。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 知识图谱——机器大脑中的知识库
    本文介绍了知识图谱在机器大脑中的应用,以及搜索引擎在知识图谱方面的发展。以谷歌知识图谱为例,说明了知识图谱的智能化特点。通过搜索引擎用户可以获取更加智能化的答案,如搜索关键词"Marie Curie",会得到居里夫人的详细信息以及与之相关的历史人物。知识图谱的出现引起了搜索引擎行业的变革,不仅美国的微软必应,中国的百度、搜狗等搜索引擎公司也纷纷推出了自己的知识图谱。 ... [详细]
  • 本文讨论了编写可保护的代码的重要性,包括提高代码的可读性、可调试性和直观性。同时介绍了优化代码的方法,如代码格式化、解释函数和提炼函数等。还提到了一些常见的坏代码味道,如不规范的命名、重复代码、过长的函数和参数列表等。最后,介绍了如何处理数据泥团和进行函数重构,以提高代码质量和可维护性。 ... [详细]
  • 本文介绍了Oracle存储过程的基本语法和写法示例,同时还介绍了已命名的系统异常的产生原因。 ... [详细]
  • 本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。 ... [详细]
author-avatar
Justine-zhu
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有