热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

梯度下降算法_用人话讲明白梯度下降GradientDescent

文章目录1.梯度2.多元线性回归参数求解3.梯度下降4.梯度下降法求解多元线性回归梯度下降算法在机器学习中出现频率特别高,是非常常用的优化算法。本文借多元线性回归&#


文章目录


1.梯度


2.多元线性回归参数求解


3.梯度下降


4.梯度下降法求解多元线性回归


梯度下降算法在机器学习中出现频率特别高,是非常常用的优化算法。


本文借多元线性回归,用人话解释清楚梯度下降的原理和步骤。


1.梯度


梯度是什么呢?


我们还是从最简单的情况说起,对于一元函数来讲,梯度就是函数的导数


而对于多元函数而言,梯度是一个向量,也就是说,把求得的偏导数以向量的形式写出来,就是梯度


例如,我们在用人话讲明白线性回归LinearRegression一文中,求未知参数β0 和β1 时,对损失函数求偏导,此时的梯度向量为:



a7b7cd87b94f65bd35f43737a8cc85eb.png

其中:



b972211d54d312c139e932d7de6e64f8.png

那篇文章中,因为一元线性回归中只有2个参数,因此令两个偏导数为0,能很容易求得β0 和β1 的解。


但是,这种求导的方法在多元回归的参数求解中就不太实用了,为什么呢?


2.多元线性回归参数求解


多元线性回归方程的一般形式为:



c7a5c3fe914ecbec205641d236facf3c.png

可以简写为矩阵形式(一般加粗表示矩阵或向量):



08a4d38bb8c5e09123cc6fbbccf9de5a.png

其中,



af09f938161ca0544fd9fb5741ef8e53.png

之前我们介绍过一元线性回归的损失函数可以用残差平方和:



c8307be1140b1a061c870cf6b83b03e0.png

代入多元线性回归方程就是:



046ede1e41c4bc31b99946689ec4295f.png

用矩阵形式表示:



86acbb5e3a1ef38200e636c1eb1b71b5.png

上面的展开过程涉及矩阵转置,这里简单提一下矩阵转置相关运算,以免之前学过但是现在忘了:



f659879a1a3f9ba1c3cea101552c830b.png

好了,按照一元线性回归求解析解的思路,现在我们要对Q求导并令导数为0(原谅我懒,后面写公式就不对向量或矩阵加粗了,大家能理解就行):



71fe3760fdd1bab50e55801c6188b8ef.png

上面的推导过程涉及矩阵求导,这里展开讲下,为什么



c78a50ea24d198cd476c1bfcf8dcc517.png

其他几项留给大家举一反三。


首先:



5564678e5d5049d7cfffbb148baf3189.png

为了直观点,我们将YTX 记为A,因为Y是n维列向量,X是n×(p+1)的矩阵,因此YTX是(p+1)维行向量:



c41d720d0fd99764bd96a50f4f78cede.png

那么上面求导可以简写为:



63439bd52ed40602d1ca3dac83b7a091.png

这种形式的矩阵求导属于分母布局,即分子为行向量或者分母为列向量(这里属于后者)。


搞不清楚的可以看看这篇:矩阵求导实例,这里我直接写出标量/列向量求导的公式,如下(y表示标量,X表示列向量):



d259612f532873020a1fc9a1a5250a40.png

根据上式,显然有:



64af91a1af842209f50cb35ad6ebf572.png

前面我们将YTX记为A,那么上面算出来的结果就是AT ,即



c263314e7fe51b3e547214b4c8f4b4dc.png

说了这么多有的没的,最终我想说是的



4ea02e10bd24244641196eb6c94c5487.png

里面涉及到矩阵求逆,但实际问题中可能X没有逆矩阵,这时计算的结果就不够精确


第二个问题就是,如果维度多、样本多,即便有逆矩阵,计算机求解的速度也会很慢


所以,基于上面这两点,一般情况下我们不会用解析解求解法求多元线性回归参数,而是采用梯度下降法,它的计算代价相对更低。


3.梯度下降


好了,重点来了,本文真正要讲的东西终于登场了。


梯度下降,就是通过一步步迭代,让所有偏导函数都下降到最低。如果觉得不好理解,我们就还是以最简单的一元函数为例开始讲。


下图是我用Excel简单画的二次函数图像(看起来有点歪,原谅我懒……懒得调整了……),函数为


y=x^2 ,它的导数为y=2x。



e2345fe51c10329b9a189e3c559a37b8.png

如果我们初始化的点在x=1处,它的导函数值,也就是梯度值是2,为正,那就让它往左移一点,继续计算它的梯度值,若为正,就继续往左移。


如果我们初始化的点在x=-1处,该处的梯度值是-2,为负,那就让它往右移。


多元函数的逻辑也一样,先初始化一个点,也就是随便选择一个位置,计算它的梯度,然后往梯度相反的方向,每次移动一点点,直到达到停止条件


这个停止条件,可以是足够大的迭代步数,也可以是一个比较小的阈值,当两次迭代之间的差值小于该阈值时,认为梯度已经下降到最低点附近了。



4bd99e01371b9fa6bdb98ae95f01dac5.png

二元函数的梯度下降示例如上图(图片来自梯度下降),对于这种非凸函数,可能会出现这种情况:初始化的点不同,最后的结果也不同,也就是陷入局部最小值



e349858e838353d2b798ab1b21eed881.png

这种问题比较有效的解决方法,就是多取几个初始点。不过对于我们接下来讲的多元线性回归,以及后面要讲的逻辑回归,都不存在这个问题,因为他们的损失函数都是凸函数,有全局最小值。


用数学公式来描述梯度下降的步骤,就是:



b878f2e3f0ef6717be19d3fd8578af97.png

解释下公式含义:


  • θk 为k时刻的点坐标,θk+1 为下一刻要移动到的点的坐标,例如 θ0就代表初始化的点坐标, θ1就代表第一步到移动到的位置;
  • g代表梯度,前面有个负号,就代表梯度下降,即朝着梯度相反的反向移动;
  • α 被称为步长,用它乘以梯度值来控制每次移动的距离,这个值的设定也是一门学问,设定的过小,迭代的次数就会过多,设定的过大,容易一步跨太远,直接跳过了最小值。


f3d7e557356c696f4cea7032c98d723c.png

4.梯度下降法求解多元线性回归


回到前面的多元线性回归,我们用梯度下降算法求损失函数的最小值。


首先,求梯度,也就是前面我们已经给出的求偏导的公式:



16b7aa35599e722fd4606874fb5a8112.png

将梯度代入随机梯度下降公式:



4758dff064c9625a7b6ecfb74ef0b453.png

这个式子中,X矩阵和Y向量都是已知的,步长是人为设定的一个值,只有参数β是未知的,而每一步的


θ 是由β 决定的,也就是每一步的点坐标。


算法过程:


1. 初始化β 向量的值,即θ0 ,将其代入导函数得到当前位置的梯度;


2. 用步长α 乘以当前梯度,得到从当前位置下降的距离;


3. 更新θ1,其更新表达式为



f1b33857ae834e9fe45c415b9232a7c2.png

4. 重复以上步骤,直到更新到某个θk,达到停止条件,这个θk就是我们求解的参数向量。


参考链接:


深入浅出--梯度下降法及其实现


梯度下降与随机梯度下降概念及推导过程




推荐阅读
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • Android中高级面试必知必会,积累总结
    本文介绍了Android中高级面试的必知必会内容,并总结了相关经验。文章指出,如今的Android市场对开发人员的要求更高,需要更专业的人才。同时,文章还给出了针对Android岗位的职责和要求,并提供了简历突出的建议。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文介绍了Java工具类库Hutool,该工具包封装了对文件、流、加密解密、转码、正则、线程、XML等JDK方法的封装,并提供了各种Util工具类。同时,还介绍了Hutool的组件,包括动态代理、布隆过滤、缓存、定时任务等功能。该工具包可以简化Java代码,提高开发效率。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • GPT-3发布,动动手指就能自动生成代码的神器来了!
    近日,OpenAI发布了最新的NLP模型GPT-3,该模型在GitHub趋势榜上名列前茅。GPT-3使用的数据集容量达到45TB,参数个数高达1750亿,训练好的模型需要700G的硬盘空间来存储。一位开发者根据GPT-3模型上线了一个名为debuid的网站,用户只需用英语描述需求,前端代码就能自动生成。这个神奇的功能让许多程序员感到惊讶。去年,OpenAI在与世界冠军OG战队的表演赛中展示了他们的强化学习模型,在限定条件下以2:0完胜人类冠军。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 手把手教你使用GraphPad Prism和Excel绘制回归分析结果的森林图
    本文介绍了使用GraphPad Prism和Excel绘制回归分析结果的森林图的方法。通过展示森林图,可以更加直观地将回归分析结果可视化。GraphPad Prism是一款专门为医学专业人士设计的绘图软件,同时也兼顾统计分析的功能,操作便捷,可以帮助科研人员轻松绘制出高质量的专业图形。文章以一篇发表在JACC杂志上的研究为例,利用其中的多因素回归分析结果来绘制森林图。通过本文的指导,读者可以学会如何使用GraphPad Prism和Excel绘制回归分析结果的森林图。 ... [详细]
  • 建立分类感知器二元模型对样本数据进行分类
    本文介绍了建立分类感知器二元模型对样本数据进行分类的方法。通过建立线性模型,使用最小二乘、Logistic回归等方法进行建模,考虑到可能性的大小等因素。通过极大似然估计求得分类器的参数,使用牛顿-拉菲森迭代方法求解方程组。同时介绍了梯度上升算法和牛顿迭代的收敛速度比较。最后给出了公式法和logistic regression的实现示例。 ... [详细]
  • 前言:拿到一个案例,去分析:它该是做分类还是做回归,哪部分该做分类,哪部分该做回归,哪部分该做优化,它们的目标值分别是什么。再挑影响因素,哪些和分类有关的影响因素,哪些和回归有关的 ... [详细]
  • 导出功能protectedvoidbtnExport(objectsender,EventArgse){用来打开下载窗口stringfileName中 ... [详细]
author-avatar
lw65112779
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有