热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

如何优雅的拟合非线性曲线

这次真的和阿猪老师学到了,深度学习降维打击数据拟合竟然是这么的高效!某天下班,小洛在微信上忽然呼叫我:我一看好家伙ÿ

这次真的和阿猪老师学到了,深度学习降维打击数据拟合竟然是这么的高效!


某天下班,小洛在微信上忽然呼叫我:

97d3829d01204d858700c0033ef02dfa.png

bbacd4f7b4c9c1bffbaffb7a9d5c3391.png

我一看好家伙,这不是手到擒来?这也太简单了,直接开整:

9d4bd33a0c748e5709b8344eae4b0172.png

一、curve_fit函数拟合

众所周知,scipy就可以进行曲线拟合,直接开整:

2a07897220e7242f309ee1771805371e.png

输出效果:

5aae505fd25d1dda5cac4128befd0460.png

某位同学一头雾水:我要的是非线性的函数,为啥出来的是一条直线?还相差这么远?

二、这个表达式是合理的吗?

优化的第一步:可以看到函数内的-b*(x-a) 首先就可以被简化成bx+a,来把参数合并掉(冷知识,在这里其实无所谓参数的正负号,优化方法会自动去进行拟合的)

输出:

a22a814d7d286cba8c95b1bd980210e0.png

这一个简单的改动,直接让mse从8e-8下降到3e-9,约20倍。

0906e2a6033e8f866ee7b81f52a7a4ed.png

问:为什么拟合的曲线还是一条直线,不是小洛同学想要的那条曲线?

答:因为curve_fit函数使用最小二乘法对函数进行拟合,适用的范围仅为使用线性参数组合的非线性函数,例如:

a82a72fe3b729c044de00fdca828b09e.png

其中的a0、a1即为可以拟合的参数,θ0(x)则是非线性的某些函数,而小洛给出的sigmoid的函数形式并不是这样,因此用curve_fit拟合出来的效果并不让人满意。

高中生涯渐渐远去,相信同学们的数学知识忘得也差不多了,在这里简单地提一下函数拟合中需要注意的问题:

1、函数的零点、缩放参数

例如sigmoid的原始形式为:

043da6e99182f1ab70ce64678dd76f9f.png

在x=0的时候y取得0.5. 如果想要在x轴上做左右平移,就需要把x替换成x+a的形式,成为:

c2f26971937cbb32031f7eee26020099.png

而在y轴上做上下平移,同样地要在y替换成y+b的形式;还有横向、纵向的拉伸,也对应了x需要替换成c*x, y需要替换成d*y 。因为期待被拟合的参数自行决定正负,所以放分子和分母都没关系;经过项式调整,我们形成了最终表达式,对于原先sigmoid函数而言,这个格式可以支持平移和缩放的形式:

4e231ea62556303a77e06a757e5b5345.png

我们把原先公式改写为上面的公式套进去,再用curve_fit优化输出:

b0c57b06a16ee46610271b43758a3491.png

可以看到mse进一步下降到2e-9,也终于有了一点点的曲线样子,然鹅还有很大提升空间。

2、函数的定义域

523e15d6c044f618a8eda5c9ca0e8235.png

如果所选用的函数对定义域有要求,比如选用典型的对数函数log(x) ,y=log(ax+b)的时候。要复习一下数学知识,知道对数函数的定义域(也就是此处ax+b的范围),要求是大于0的,但是我们给出的这个形式并不能保证ax+b一定大于0,就会超出计算范围,各大计算库此时会返回nan,此时可以在ax+b外面加一个绝对值。

3、函数的值域

小洛要拟合的sigmoid除了非线性的特点,还要注意一个问题:它的值域,也就是y能够抵达的范围。这里如果我们确定定义域的范围是固定的范围内还好,就可以把最终拟合的值域全部算出来。原始sigmoid的值域是0~1范围内,因此就算给y加上了缩放,最终也只能到达0~c的范围内

在寻常的非线性函数的拟合里,如果需要继续使用curve_fit,就需要将其转化为线性的形式,例如指数函数:

68f52f8e0c30a5e028383d48d341cc27.png

它的拟合可以取对数,让其对lny~x的关系进行拟合:

1149d4292dc8f47fe7b07fc979b197ca.png

看来,curve_fit函数并不能用来很好的拟合sigmoid函数,即使表达式已经已经优化到最终格式,拟合出的却还是只有微微的曲线,并不是sigmoid曲线的样子。要强行用curve_fit的话只能做更复杂的变换去依靠中间表达式去间接拟合。

2044c861e0f389c86858763eb6d67b49.png

(那么,能不能再给力一点,这个好复杂啊不懂.jpg)

三、非线性函数拟合的通用解决法

当然对于很多同学们来说,会发现继续用curve_fit函数,把非线性函数转化为若干参数与函数的乘法逻辑过于复杂,超出了自己的数学能力,那能不能有一个更通用、简单的方式去解决呢?

当然有!

作为一名工作在人工智障领域的算法工程师,现在的名词“深度学习”正是为了解决超级复杂的非线性问题而产生的,深度学习的模型正是非线性的典型代表。这个非线性已经叠加到了人类无法用手动去对公式进行优化的复杂度。

非线性导致深度学习在处理深层次的信息上,性能非常优异,在图像识别、图文理解这些领域,都得到了超越人类的特性;我们现在看到的各种商品、图文、短视频,最终也是靠深度学习,计算出这些单品的向量,再和我们自身的兴趣向量做运算,影响了搜索、推荐时的排序。

ceb8933e59ccaea62f12f64dec659483.png

(深度学习.blingbling.jpg)

小洛这次仅仅只要拟合这个sigmoid非线性函数,但如果下次是其他曲线的函数的话,就又要经过其他的手工的表达式变化与复杂的构造,这样不行!

在这里,我们可以使用深度学习里的优化原理,找到一个针对非线性模型拟合的通用解决方法,这里使用torch进行优化。

首先我们对函数进行少量的改动,其中包含了我们前面说的平移、缩放的内容,因此带来了4个参数:

8bf7b6db4175fc171ca87d2000d22c8e.png

所有深度学习优化的原理都是一样,其中核心规则很简单:

1、对计算进行求(偏)导

例如y=a+bx 我们要知道每个参数对于最终输入的y产生了多大的作用,就是求参数a对于y的偏导(理解为此处的a的变化,会带来多大的y的变化)

2.使用链式法则进行导数的传递

例如复合函数:

0c4bec566e6fcb80b3b8afcb2f89f5b0.png

是复合函数:

369d068f1c5a44804aeb561c6268d30f.png

的传递,链式法则告诉我们,这样的复合函数求导遵循规则:

e097997f6541cfad169311ec4191eaf4.png

参数a对于y的导数,等于复合函数g对y的导数乘以 a对函数g的导数;以此类推,嵌套的复合函数就这样依次传递下来即可;

3.知道真值和预测值的差异

(小洛此处定义为mse),我们平常意义上也即定义为损失。可以由此计算各个参数的导数,那么参数们沿着最小化损失的方向迈进,就可以让损失降低。这在深度学习里通过优化器的方式来进行。

这几个规则组合到一起,也就是传说中的SGD梯度下降法:

16f02c545d995beb7bd326a5f8016bcf.png

当然实际的算法中使用的是改进的一个优化器Adam,我们加一点点细节,写出如下代码(其中用来优化函数参数的代码甚至也就10行):

70b61024dc7c471adc8bec6dff015d2c.png

输出效果非常美好,完美回答了小洛的问题:

d93e513032c3c0bc13aaaf4d00eda490.png

四、为什么效果还差一点?

虽然目前已经完美的回答了小洛的sigmoid拟合问题。但实际使用上,可能大家还是会遇到一些不完美的情况。一种情况是前置没有设计好需要拟合的函数,导致函数的作用域和定义域出现问题,与实际的数据不匹配;另一种则是代码中涉及的超参了;

这里有几个比较魔法的数字,我们来讲一下这些魔法的原理:

1、为什么有个bigger参数

因为x和y的范围都很小,尤其是y,在0.01范围内,这个范围假如一个函数的预测始终输出0.012,最后的mse一平方,会直接落到1e-7范围以内,这个数字太小了,以至于低于了许多计算中eps的范畴(一个例子是为了避免除以0,除法会成为1/(x+eps)),这会导致函数难以进行有效的优化。

2、Learning rate为什么是0.002

这个叫做学习率的参数,在实际的应用场景中也需要在一定的范围内进行调整。前面我们说到,这个是优化器根据参数的梯度,往前前进的步长,因此这个步长跟实际的非线性函数的特点相关;举个例子:我们知道x^2函数的导数是x,所以在x=2处导数是2,步长决定了我要往最小化损失的地方迈多远,比如此时如果学习率是2,下一次我会迈往x=2-2*2=-2 的位置,这样就很难迈向函数的最小值部分了。

3、StepLR起到一个什么作用?为什么是0.8

一般来说,优化到了后面,需要一个更保守的步长,这个可以通过观察loss的下降情况来进行调整。

比如小洛遇到的另一个案例问题:(和上面的案例不同,这次是要拟合log函数曲线)

5731733b86b3657de849b213c146bb4f.png

她这个拟合出来的曲线就不是很棒棒。这个一来可以对上面的超参学习率进行调整。二来还可以调整下公式,看看是不是在函数本身之外,再加个线性的部分比如y=log(ax+b)+cx+d之类的。

8ff93a0801ddf78076066ba69def6149.png

聪明的小洛已经很快地学会了,同学们也都可以一起用起来这个可以适用于任意数据函数拟合的万能方法啦!

End


推荐阅读
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • 【观察】中国产业AI化的破局之路:加速算力释放与生态合作共赢
    申耀的科技观察读懂科技,赢取未来!电影《斗士》中,有这么一句台词令人印象深刻:“知道路要怎么走,和走上这条路& ... [详细]
  • 分类与聚类
    一:分类1:定义分类其实是从特定的数据中挖掘模式,做出判断的过程。分类是在一群已经知道类别标号的样本中,训练一种分类器 ... [详细]
  • 快过HugeCTR:用OneFlow轻松实现大型推荐系统引擎
    一、简介Wide&DeepLearning(以下简称WDL)是解决点击率预估(CTRPrediction) ... [详细]
  • 学习SLAM的女生,很酷
    本文介绍了学习SLAM的女生的故事,她们选择SLAM作为研究方向,面临各种学习挑战,但坚持不懈,最终获得成功。文章鼓励未来想走科研道路的女生勇敢追求自己的梦想,同时提到了一位正在英国攻读硕士学位的女生与SLAM结缘的经历。 ... [详细]
  • 近年来,大数据成为互联网世界的新宠儿,被列入阿里巴巴、谷歌等公司的战略规划中,也在政府报告中频繁提及。据《大数据人才报告》显示,目前全国大数据人才仅46万,未来3-5年将出现高达150万的人才缺口。根据领英报告,数据剖析人才供应指数最低,且跳槽速度最快。中国商业结合会数据剖析专业委员会统计显示,未来中国基础性数据剖析人才缺口将高达1400万。目前BAT企业中,60%以上的招聘职位都是针对大数据人才的。 ... [详细]
  • 建立分类感知器二元模型对样本数据进行分类
    本文介绍了建立分类感知器二元模型对样本数据进行分类的方法。通过建立线性模型,使用最小二乘、Logistic回归等方法进行建模,考虑到可能性的大小等因素。通过极大似然估计求得分类器的参数,使用牛顿-拉菲森迭代方法求解方程组。同时介绍了梯度上升算法和牛顿迭代的收敛速度比较。最后给出了公式法和logistic regression的实现示例。 ... [详细]
  • OCR:用字符识别方法将形状翻译成计算机文字的过程Matlab:商业数学软件;CUDA:CUDA™是一种由NVIDIA推 ... [详细]
  • 3年半巨亏242亿!商汤高估了深度学习,下错了棋?
    转自:新智元三年半研发开支近70亿,累计亏损242亿。AI这门生意好像越来越不好做了。近日,商汤科技已向港交所递交IPO申请。招股书显示& ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
  • Two Sigma人均22万英镑~
    近期原创文章: ... [详细]
  • 干货 | 携程AI推理性能的自动化优化实践
    作者简介携程度假AI研发团队致力于为携程旅游事业部提供丰富的AI技术产品,其中性能优化组为AI模型提供全方位的优化方案,提升推理性能降低成本࿰ ... [详细]
  • 基于深度学习的遥感应用
    文章目录深度学习的发展过程深度学习在遥感中的应用基于深度学习的遥感样例库建设基于深度学习的遥感影像目标及场景检索基于深度学习的建筑物提取基于深度学习的密集建筑物自动检测基于深度学习 ... [详细]
  • 开源真香 离线识别率高 Python 人脸识别系统
    本文主要介绍关于python,人工智能,计算机视觉的知识点,对【开源真香离线识别率高Python人脸识别系统】和【】有兴趣的朋友可以看下由【000X000】投稿的技术文章,希望该技术和经验能帮到 ... [详细]
  • 脑机接口和卷积神经网络的初学指南(一)
    脑机接口和卷积神经网络的初学指南(一) ... [详细]
author-avatar
phpxiaohui
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有