热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

GBDT算法解析

在网上看到一篇对从代码层面理解gbdt比较好的文章,转载记录一下:GBDT(GradientBoostingDecisionTree)又叫MART&#x

在网上看到一篇对从代码层面理解gbdt比较好的文章,转载记录一下:

       

       GBDT(Gradient Boosting Decision Tree) 又叫 MART(Multiple Additive Regression Tree),是一种迭代的决策树算法,该算法由多棵决策树组成,所有树的结论累加起来做最终答案。它在被提出之初就和SVM一起被认为是泛化能力(generalization)较强的算法。近些年更因为被用于搜索排序的机器学习模型而引起大家关注。

 

后记:发现GBDT除了我描述的残差版本外还有另一种GBDT描述,两者大概相同,但求解方法(Gradient应用)不同。其区别和另一版本的介绍链接见这里。由于另一版本介绍博客中亦有不少错误,建议大家还是先看本篇,再跳到另一版本描述,这个顺序当能两版本都看懂。


第1~4节:GBDT算法内部究竟是如何工作的?

第5节:它可以用于解决哪些问题?

第6节:它又是怎样应用于搜索排序的呢? 

 

在此先给出我比较推荐的两篇英文文献,喜欢英文原版的同学可直接阅读:

【1】Boosting Decision Tree入门教程 http://www.schonlau.net/publication/05stata_boosting.pdf

【2】LambdaMART用于搜索排序入门教程 http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf

 

GBDT主要由三个概念组成:Regression Decistion Tree(即DT),Gradient Boosting(即GB),Shrinkage (算法的一个重要演进分枝,目前大部分源码都按该版本实现)。搞定这三个概念后就能明白GBDT是如何工作的,要继续理解它如何用于搜索排序则需要额外理解RankNet概念,之后便功德圆满。下文将逐个碎片介绍,最终把整张图拼出来。

 

一、 DT:回归树 Regression Decision Tree

提起决策树(DT, Decision Tree) 绝大部分人首先想到的就是C4.5分类决策树。但如果一开始就把GBDT中的树想成分类树,那就是一条歪路走到黑,一路各种坑,最终摔得都要咯血了还是一头雾水说的就是LZ自己啊有木有。咳嗯,所以说千万不要以为GBDT是很多棵分类树。决策树分为两大类,回归树和分类树。前者用于预测实数值,如明天的温度、用户的年龄、网页的相关程度;后者用于分类标签值,如晴天/阴天/雾/雨、用户性别、网页是否是垃圾页面。这里要强调的是,前者的结果加减是有意义的,如10岁+5岁-3岁=12岁,后者则无意义,如男+男+女=到底是男是女? GBDT的核心在于累加所有树的结果作为最终结果,就像前面对年龄的累加(-3是加负3),而分类树的结果显然是没办法累加的,所以GBDT中的树都是回归树,不是分类树,这点对理解GBDT相当重要(尽管GBDT调整后也可用于分类但不代表GBDT的树是分类树)。那么回归树是如何工作的呢?

 

下面我们以对人的性别判别/年龄预测为例来说明,每个instance都是一个我们已知性别/年龄的人,而feature则包括这个人上网的时长、上网的时段、网购所花的金额等。

 

作为对比&#xff0c;先说分类树&#xff0c;我们知道C4.5分类树在每次分枝时&#xff0c;是穷举每一个feature的每一个阈值&#xff0c;找到使得按照feature<&#61;阈值&#xff0c;和feature>阈值分成的两个分枝的熵最大的feature和阈值&#xff08;熵最大的概念可理解成尽可能每个分枝的男女比例都远离1:1&#xff09;&#xff0c;按照该标准分枝得到两个新节点&#xff0c;用同样方法继续分枝直到所有人都被分入性别唯一的叶子节点&#xff0c;或达到预设的终止条件&#xff0c;若最终叶子节点中的性别不唯一&#xff0c;则以多数人的性别作为该叶子节点的性别。

 

回归树总体流程也是类似&#xff0c;不过在每个节点&#xff08;不一定是叶子节点&#xff09;都会得一个预测值&#xff0c;以年龄为例&#xff0c;该预测值等于属于这个节点的所有人年龄的平均值。分枝时穷举每一个feature的每个阈值找最好的分割点&#xff0c;但衡量最好的标准不再是最大熵&#xff0c;而是最小化均方差--即&#xff08;每个人的年龄-预测年龄&#xff09;^2 的总和 / N&#xff0c;或者说是每个人的预测误差平方和 除以 N。这很好理解&#xff0c;被预测出错的人数越多&#xff0c;错的越离谱&#xff0c;均方差就越大&#xff0c;通过最小化均方差能够找到最靠谱的分枝依据。分枝直到每个叶子节点上人的年龄都唯一&#xff08;这太难了&#xff09;或者达到预设的终止条件&#xff08;如叶子个数上限&#xff09;&#xff0c;若最终叶子节点上人的年龄不唯一&#xff0c;则以该节点上所有人的平均年龄做为该叶子节点的预测年龄。若还不明白可以Google "Regression Tree"&#xff0c;或阅读本文的第一篇论文中Regression Tree部分。

 

二、 GB&#xff1a;梯度迭代 Gradient Boosting

好吧&#xff0c;我起了一个很大的标题&#xff0c;但事实上我并不想多讲Gradient Boosting的原理&#xff0c;因为不明白原理并无碍于理解GBDT中的Gradient Boosting。喜欢打破砂锅问到底的同学可以阅读这篇英文wikihttp://en.wikipedia.org/wiki/Gradient_boosted_trees#Gradient_tree_boosting

 

Boosting&#xff0c;迭代&#xff0c;即通过迭代多棵树来共同决策。这怎么实现呢&#xff1f;难道是每棵树独立训练一遍&#xff0c;比如A这个人&#xff0c;第一棵树认为是10岁&#xff0c;第二棵树认为是0岁&#xff0c;第三棵树认为是20岁&#xff0c;我们就取平均值10岁做最终结论&#xff1f;--当然不是&#xff01;且不说这是投票方法并不是GBDT&#xff0c;只要训练集不变&#xff0c;独立训练三次的三棵树必定完全相同&#xff0c;这样做完全没有意义。之前说过&#xff0c;GBDT是把所有树的结论累加起来做最终结论的&#xff0c;所以可以想到每棵树的结论并不是年龄本身&#xff0c;而是年龄的一个累加量。GBDT的核心就在于&#xff0c;每一棵树学的是之前所有树结论和的残差&#xff0c;这个残差就是一个加预测值后能得真实值的累加量。比如A的真实年龄是18岁&#xff0c;但第一棵树的预测年龄是12岁&#xff0c;差了6岁&#xff0c;即残差为6岁。那么在第二棵树里我们把A的年龄设为6岁去学习&#xff0c;如果第二棵树真的能把A分到6岁的叶子节点&#xff0c;那累加两棵树的结论就是A的真实年龄&#xff1b;如果第二棵树的结论是5岁&#xff0c;则A仍然存在1岁的残差&#xff0c;第三棵树里A的年龄就变成1岁&#xff0c;继续学。这就是Gradient Boosting在GBDT中的意义&#xff0c;简单吧。

 

三、 GBDT工作过程实例。

还是年龄预测&#xff0c;简单起见训练集只有4个人&#xff0c;A,B,C,D&#xff0c;他们的年龄分别是14,16,24,26。其中A、B分别是高一和高三学生&#xff1b;C,D分别是应届毕业生和工作两年的员工。如果是用一棵传统的回归决策树来训练&#xff0c;会得到如下图1所示结果&#xff1a;


 

现在我们使用GBDT来做这件事&#xff0c;由于数据太少&#xff0c;我们限定叶子节点做多有两个&#xff0c;即每棵树都只有一个分枝&#xff0c;并且限定只学两棵树。我们会得到如下图2所示结果&#xff1a;

 

在第一棵树分枝和图1一样&#xff0c;由于A,B年龄较为相近&#xff0c;C,D年龄较为相近&#xff0c;他们被分为两拨&#xff0c;每拨用平均年龄作为预测值。此时计算残差&#xff08;残差的意思就是&#xff1a; A的预测值 &#43; A的残差 &#61; A的实际值&#xff09;&#xff0c;所以A的残差就是16-15&#61;1&#xff08;注意&#xff0c;A的预测值是指前面所有树累加的和&#xff0c;这里前面只有一棵树所以直接是15&#xff0c;如果还有树则需要都累加起来作为A的预测值&#xff09;。进而得到A,B,C,D的残差分别为-1,1&#xff0c;-1,1。然后我们拿残差替代A,B,C,D的原值&#xff0c;到第二棵树去学习&#xff0c;如果我们的预测值和它们的残差相等&#xff0c;则只需把第二棵树的结论累加到第一棵树上就能得到真实年龄了。这里的数据显然是我可以做的&#xff0c;第二棵树只有两个值1和-1&#xff0c;直接分成两个节点。此时所有人的残差都是0&#xff0c;即每个人都得到了真实的预测值。

 

换句话说&#xff0c;现在A,B,C,D的预测值都和真实年龄一致了。Perfect!&#xff1a;

A: 14岁高一学生&#xff0c;购物较少&#xff0c;经常问学长问题&#xff1b;预测年龄A &#61; 15 – 1 &#61; 14

B: 16岁高三学生&#xff1b;购物较少&#xff0c;经常被学弟问问题&#xff1b;预测年龄B &#61; 15 &#43; 1 &#61; 16

C: 24岁应届毕业生&#xff1b;购物较多&#xff0c;经常问师兄问题&#xff1b;预测年龄C &#61; 25 – 1 &#61; 24

D: 26岁工作两年员工&#xff1b;购物较多&#xff0c;经常被师弟问问题&#xff1b;预测年龄D &#61; 25 &#43; 1 &#61; 26 

 

那么哪里体现了Gradient呢&#xff1f;其实回到第一棵树结束时想一想&#xff0c;无论此时的cost function是什么&#xff0c;是均方差还是均差&#xff0c;只要它以误差作为衡量标准&#xff0c;残差向量(-1, 1, -1, 1)都是它的全局最优方向&#xff0c;这就是Gradient。

 

讲到这里我们已经把GBDT最核心的概念、运算过程讲完了&#xff01;没错就是这么简单。不过讲到这里很容易发现三个问题&#xff1a;

 

1&#xff09;既然图1和图2 最终效果相同&#xff0c;为何还需要GBDT呢&#xff1f;

答案是过拟合。过拟合是指为了让训练集精度更高&#xff0c;学到了很多”仅在训练集上成立的规律“&#xff0c;导致换一个数据集当前规律就不适用了。其实只要允许一棵树的叶子节点足够多&#xff0c;训练集总是能训练到100%准确率的&#xff08;大不了最后一个叶子上只有一个instance)。在训练精度和实际精度&#xff08;或测试精度&#xff09;之间&#xff0c;后者才是我们想要真正得到的。

我们发现图1为了达到100%精度使用了3个feature&#xff08;上网时长、时段、网购金额&#xff09;&#xff0c;其中分枝“上网时长>1.1h” 很显然已经过拟合了&#xff0c;这个数据集上A,B也许恰好A每天上网1.09h, B上网1.05小时&#xff0c;但用上网时间是不是>1.1小时来判断所有人的年龄很显然是有悖常识的&#xff1b;

相对来说图2的boosting虽然用了两棵树 &#xff0c;但其实只用了2个feature就搞定了&#xff0c;后一个feature是问答比例&#xff0c;显然图2的依据更靠谱。&#xff08;当然&#xff0c;这里是LZ故意做的数据&#xff0c;所以才能靠谱得如此狗血。实际中靠谱不靠谱总是相对的&#xff09; Boosting的最大好处在于&#xff0c;每一步的残差计算其实变相地增大了分错instance的权重&#xff0c;而已经分对的instance则都趋向于0。这样后面的树就能越来越专注那些前面被分错的instance。就像我们做互联网&#xff0c;总是先解决60%用户的需求凑合着&#xff0c;再解决35%用户的需求&#xff0c;最后才关注那5%人的需求&#xff0c;这样就能逐渐把产品做好&#xff0c;因为不同类型用户需求可能完全不同&#xff0c;需要分别独立分析。如果反过来做&#xff0c;或者刚上来就一定要做到尽善尽美&#xff0c;往往最终会竹篮打水一场空。

 

2&#xff09;Gradient呢&#xff1f;不是“G”BDT么&#xff1f;

 到目前为止&#xff0c;我们的确没有用到求导的Gradient。在当前版本GBDT描述中&#xff0c;的确没有用到Gradient&#xff0c;该版本用残差作为全局最优的绝对方向&#xff0c;并不需要Gradient求解.


 

3&#xff09;这不是boosting吧&#xff1f;Adaboost可不是这么定义的。

这是boosting&#xff0c;但不是Adaboost。GBDT不是Adaboost Decistion Tree。就像提到决策树大家会想起C4.5&#xff0c;提到boost多数人也会想到Adaboost。Adaboost是另一种boost方法&#xff0c;它按分类对错&#xff0c;分配不同的weight&#xff0c;计算cost function时使用这些weight&#xff0c;从而让“错分的样本权重越来越大&#xff0c;使它们更被重视”。Bootstrap也有类似思想&#xff0c;它在每一步迭代时不改变模型本身&#xff0c;也不计算残差&#xff0c;而是从N个instance训练集中按一定概率重新抽取N个instance出来&#xff08;单个instance可以被重复sample&#xff09;&#xff0c;对着这N个新的instance再训练一轮。由于数据集变了迭代模型训练结果也不一样&#xff0c;而一个instance被前面分错的越厉害&#xff0c;它的概率就被设的越高&#xff0c;这样就能同样达到逐步关注被分错的instance&#xff0c;逐步完善的效果。Adaboost的方法被实践证明是一种很好的防止过拟合的方法&#xff0c;但至于为什么则至今没从理论上被证明。GBDT也可以在使用残差的同时引入Bootstrap re-sampling&#xff0c;GBDT多数实现版本中也增加的这个选项&#xff0c;但是否一定使用则有不同看法。re-sampling一个缺点是它的随机性&#xff0c;即同样的数据集合训练两遍结果是不一样的&#xff0c;也就是模型不可稳定复现&#xff0c;这对评估是很大挑战&#xff0c;比如很难说一个模型变好是因为你选用了更好的feature&#xff0c;还是由于这次sample的随机因素。

 

 

四、Shrinkage 

Shrinkage&#xff08;缩减&#xff09;的思想认为&#xff0c;每次走一小步逐渐逼近结果的效果&#xff0c;要比每次迈一大步很快逼近结果的方式更容易避免过拟合。即它不完全信任每一个棵残差树&#xff0c;它认为每棵树只学到了真理的一小部分&#xff0c;累加的时候只累加一小部分&#xff0c;通过多学几棵树弥补不足。用方程来看更清晰&#xff0c;即

没用Shrinkage时&#xff1a;&#xff08;yi表示第i棵树上y的预测值&#xff0c; y(1~i)表示前i棵树y的综合预测值&#xff09;

y(i&#43;1) &#61; 残差(y1~yi)&#xff0c; 其中&#xff1a; 残差(y1~yi) &#61;  y真实值 - y(1 ~ i)

y(1 ~ i) &#61; SUM(y1, ..., yi)

Shrinkage不改变第一个方程&#xff0c;只把第二个方程改为&#xff1a; 

y(1 ~ i) &#61; y(1 ~ i-1) &#43; step * yi

 

即Shrinkage仍然以残差作为学习目标&#xff0c;但对于残差学习出来的结果&#xff0c;只累加一小部分&#xff08;step*残差&#xff09;逐步逼近目标&#xff0c;step一般都比较小&#xff0c;如0.01~0.001&#xff08;注意该step非gradient的step&#xff09;&#xff0c;导致各个树的残差是渐变的而不是陡变的。直觉上这也很好理解&#xff0c;不像直接用残差一步修复误差&#xff0c;而是只修复一点点&#xff0c;其实就是把大步切成了很多小步。本质上&#xff0c;Shrinkage为每棵树设置了一个weight&#xff0c;累加时要乘以这个weight&#xff0c;但和Gradient并没有关系。这个weight就是step。就像Adaboost一样&#xff0c;Shrinkage能减少过拟合发生也是经验证明的&#xff0c;目前还没有看到从理论的证明。


五、 GBDT的适用范围

该版本GBDT几乎可用于所有回归问题&#xff08;线性/非线性&#xff09;&#xff0c;相对logistic regression仅能用于线性回归&#xff0c;GBDT的适用面非常广。亦可用于二分类问题&#xff08;设定阈值&#xff0c;大于阈值为正例&#xff0c;反之为负例&#xff09;。

 

六、 搜索引擎排序应用 RankNet

搜索排序关注各个doc的顺序而不是绝对值&#xff0c;所以需要一个新的cost function&#xff0c;而RankNet基本就是在定义这个cost function&#xff0c;它可以兼容不同的算法&#xff08;GBDT、神经网络...&#xff09;。


实际的搜索排序使用的是LambdaMART算法&#xff0c;必须指出的是由于这里要使用排序需要的cost function&#xff0c;LambdaMART迭代用的并不是残差。Lambda在这里充当替代残差的计算方法&#xff0c;它使用了一种类似Gradient*步长模拟残差的方法。这里的MART在求解方法上和之前说的残差略有不同&#xff0c;其区别描述见这里。


就像所有的机器学习一样&#xff0c;搜索排序的学习也需要训练集&#xff0c;这里一般是用人工标注实现&#xff0c;即对每一个(query,doc) pair给定一个分值&#xff08;如1,2,3,4&#xff09;,分值越高表示越相关&#xff0c;越应该排到前面。然而这些绝对的分值本身意义不大&#xff0c;例如你很难说1分和2分文档的相关程度差异是1分和3分文档差距的一半。相关度本身就是一个很主观的评判&#xff0c;标注人员无法做到这种定量标注&#xff0c;这种标准也无法制定。但标注人员很容易做到的是”AB都不错&#xff0c;但文档A比文档B更相关&#xff0c;所以A是4分&#xff0c;B是3分“。RankNet就是基于此制定了一个学习误差衡量方法&#xff0c;即cost function。具体而言&#xff0c;RankNet对任意两个文档A,B&#xff0c;通过它们的人工标注分差&#xff0c;用sigmoid函数估计两者顺序和逆序的概率P1。然后同理用机器学习到的分差计算概率P2&#xff08;sigmoid的好处在于它允许机器学习得到的分值是任意实数值&#xff0c;只要它们的分差和标准分的分差一致&#xff0c;P2就趋近于P1&#xff09;。这时利用P1和P2求的两者的交叉熵&#xff0c;该交叉熵就是cost function。它越低说明机器学得的当前排序越趋近于标注排序。为了体现NDCG的作用&#xff08;NDCG是搜索排序业界最常用的评判标准&#xff09;&#xff0c;RankNet还在cost function中乘以了NDCG。


好&#xff0c;现在我们有了cost function&#xff0c;而且它是和各个文档的当前分值yi相关的&#xff0c;那么虽然我们不知道它的全局最优方向&#xff0c;但可以求导求Gradient&#xff0c;Gradient即每个文档得分的一个下降方向组成的N维向量&#xff0c;N为文档个数&#xff08;应该说是query-doc pair个数&#xff09;。这里仅仅是把”求残差“的逻辑替换为”求梯度“&#xff0c;可以这样想&#xff1a;梯度方向为每一步最优方向&#xff0c;累加的步数多了&#xff0c;总能走到局部最优点&#xff0c;若该点恰好为全局最优点&#xff0c;那和用残差的效果是一样的。这时套到之前讲的逻辑&#xff0c;GDBT就已经可以上了。那么最终排序怎么产生呢&#xff1f;很简单&#xff0c;每个样本通过Shrinkage累加都会得到一个最终得分&#xff0c;直接按分数从大到小排序就可以了&#xff08;因为机器学习产生的是实数域的预测分&#xff0c;极少会出现在人工标注中常见的两文档分数相等的情况&#xff0c;几乎不同考虑同分文档的排序方式&#xff09;


另外&#xff0c;如果feature个数太多&#xff0c;每一棵回归树都要耗费大量时间&#xff0c;这时每个分支时可以随机抽一部分feature来遍历求最优&#xff08;ELF源码实现方式&#xff09;。


推荐阅读
  • 本文详细介绍了 InfluxDB、collectd 和 Grafana 的安装与配置流程。首先,按照启动顺序依次安装并配置 InfluxDB、collectd 和 Grafana。InfluxDB 作为时序数据库,用于存储时间序列数据;collectd 负责数据的采集与传输;Grafana 则用于数据的可视化展示。文中提供了 collectd 的官方文档链接,便于用户参考和进一步了解其配置选项。通过本指南,读者可以轻松搭建一个高效的数据监控系统。 ... [详细]
  • 本文介绍了几种常用的图像相似度对比方法,包括直方图方法、图像模板匹配、PSNR峰值信噪比、SSIM结构相似性和感知哈希算法。每种方法都有其优缺点,适用于不同的应用场景。 ... [详细]
  • 在机器学习领域,深入探讨了概率论与数理统计的基础知识,特别是这些理论在数据挖掘中的应用。文章重点分析了偏差(Bias)与方差(Variance)之间的平衡问题,强调了方差反映了不同训练模型之间的差异,例如在K折交叉验证中,不同模型之间的性能差异显著。此外,还讨论了如何通过优化模型选择和参数调整来有效控制这一平衡,以提高模型的泛化能力。 ... [详细]
  • 在对WordPress Duplicator插件0.4.4版本的安全评估中,发现其存在跨站脚本(XSS)攻击漏洞。此漏洞可能被利用进行恶意操作,建议用户及时更新至最新版本以确保系统安全。测试方法仅限于安全研究和教学目的,使用时需自行承担风险。漏洞编号:HTB23162。 ... [详细]
  • 本文详细介绍了如何在 Linux 系统上安装 JDK 1.8、MySQL 和 Redis,并提供了相应的环境配置和验证步骤。 ... [详细]
  • com.sun.javadoc.PackageDoc.exceptions()方法的使用及代码示例 ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • Python 数据可视化实战指南
    本文详细介绍如何使用 Python 进行数据可视化,涵盖从环境搭建到具体实例的全过程。 ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • 本文介绍如何使用 Python 的 DOM 和 SAX 方法解析 XML 文件,并通过示例展示了如何动态创建数据库表和处理大量数据的实时插入。 ... [详细]
  • 本文对比了杜甫《喜晴》的两种英文翻译版本:a. Pleased with Sunny Weather 和 b. Rejoicing in Clearing Weather。a 版由 alexcwlin 翻译并经 Adam Lam 编辑,b 版则由哈佛大学的宇文所安教授 (Prof. Stephen Owen) 翻译。 ... [详细]
  • javascript分页类支持页码格式
    前端时间因为项目需要,要对一个产品下所有的附属图片进行分页显示,没考虑ajax一张张请求,所以干脆一次性全部把图片out,然 ... [详细]
  • poj 3352 Road Construction ... [详细]
  • 在软件开发过程中,经常需要将多个项目或模块进行集成和调试,尤其是当项目依赖于第三方开源库(如Cordova、CocoaPods)时。本文介绍了如何在Xcode中高效地进行多项目联合调试,分享了一些实用的技巧和最佳实践,帮助开发者解决常见的调试难题,提高开发效率。 ... [详细]
  • R语言中向量(Vector)数据类型的元素索引与访问:利用中括号[]和赋值操作符在向量末尾追加数据以扩展其长度
    在R语言中,向量(Vector)数据类型的元素可以通过中括号 `[]` 进行索引和访问。此外,利用中括号和赋值操作符,可以在向量的末尾追加新数据,从而动态地扩展向量的长度。这种方法不仅简洁高效,还能灵活地管理向量中的数据。 ... [详细]
author-avatar
ym_泳梅
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有