热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

深度学习分位数回归实现区间预测

点击“算法数据侠”,“星标”公众号第一时间获取最新推文与资源分享小侠客们周末愉快呀,又到了每周的学习时间,我是oubahe。今天我们探讨一下如何使用深度学习模型做到对目标值的区间预测。使用神经网络做回

点击“算法数据侠”,“星标”公众号

第一时间获取最新推文与资源分享

小侠客们周末愉快呀,又到了每周的学习时间,我是oubahe。今天我们探讨一下如何使用深度学习模型做到对目标值的区间预测使用神经网络做回归任务,我们使用MSE、MAE作为损失函数,最终得到的输出y通常会被近似为y的期望值,例如有两个样本:(x=1, y=3)和(x=1, y=2),那只用这两个样本训练模型,预测x=1时y的值就是2.5。但有些情况下目标值y的空间可能会比较大,只预测一个期望值并不能帮助我们做进一步的决策。我们想知道x=1时,y的值最小会是多少,最大会是多少,使用MSE、MAE这些损失函数来构建预测输出区间模型时候,往往需要对样本进行非常复杂的处理才能达到目的,而且因为数据的预处理需要加入很强的先验信息,建模效果肯定会打折扣,再一个如果数据规模比较大,那将会在数据预处理上浪费大量的时间。来吧,展示~

01

引言

这里介绍一个特殊的损失函数——分位数损失,利用分位数损失我们不需要对数据进行任何先验的处理,就可以轻松做到预测输出y的某一分位数水平值,例如5%分位数或95%分位数,利用这个输出很自然就完成预测输出范围的回归模型。分位数损失函数的表达式如下:

其中,γ是损失函数的参数,从实际意义上可以理解为是我们需要的分位数,这个损失函数从结构上看,就是以一定的概率γ惩罚预测值大于实际值,同时鼓励预测值小于实际值,这样的效果就是学得了目标y的γ分位数期望值。对于分位数损失函数的具体应用可以参考下边的例子。

02

分位数损失函数

我们这里以波士顿房价数据集为例理解一下分位数损失函数的效果。首先加载数据并进行分割,另外为了可视化方便,我们将x_test进行PCA降维并排序:

boston_data = load_boston()
x = boston_data['data']
y = boston_data['target']
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=77)
# PCA对数据进行降维处理,方便可视化
pca = PCA(n_compOnents=1)
x_test_after_pca = pca.fit_transform(x_test, y_test)
# 对数据进行排序
x_sort_id = np.argsort(x_test, axis=None)
x_test = x_test[x_sort_id]
y_test = y_test[x_sort_id]
x_for_visual = x_test_after_pca[x_sort_id]

第二步就是损失函数的定义,在Tensorflow2.0中没有分位数函数的定义,可以根据公式来进行定义:

def quantile_loss(y_pred, y_true, r=0.5):
greater_mask = K.cast((y_true <= y_pred), 'float32')
smaller_mask = K.cast((y_true > y_pred), 'float32')
return K.sum((r-1)*K.abs(smaller_mask*(y_true-y_pred)), -1)+K.sum(r*K.abs(greater_mask*(y_true-y_pred)), -1)

在这里,我们给出一种更简洁的定义方式:

def tilted_loss(q, y_true, y_pred):
e = (y_true-y_pred)
return K.mean(K.maximum(q*e, (q-1)*e), axis=-1)

进一步地,我们构建一个最简单的模型:

def gen_model(ipt_dim):
l1 = Input(shape=(ipt_dim,))
l2 = Dense(10, activation='relu', kernel_initializer=glorot_normal(), bias_initializer=zeros())(l1)
l3 = Dense(5, activation='relu', kernel_initializer=glorot_normal(), bias_initializer=zeros())(l2)
l4 = Dense(1, activation='relu', kernel_initializer=glorot_normal(), bias_initializer=zeros())(l3)
m_model = Model(inputs=l1, outputs=l4)
return m_model

调用模型,并且对结果进行可视化:

plt.figure()
plt.scatter(x_for_visual, y_test, label='actual')
q_list = [0.1, 0.5, 0.9]

for quantile in q_list:
model = gen_model(in_dims)
model.compile(loss=lambda y_t, y_p: tilted_loss(quantile, y_t, y_p), optimizer='adam')
model.fit(x_train, y_train, epochs=5, batch_size=8, verbose=1)
y_ = model.predict(x_test)
plt.plot(x_for_visual, y_, label=quantile)

plt.legend()
plt.show()

可视化的最终结果如下图所示,我们可以看到学习到的三个档位的分位数回归模型。其中0.5的与普通MAE回归结果是等价的,只不过0.5预测的是中位数而MAE预测的是期望值;0.1和0.9的两条曲线可以作为预测结果的上下界,能够包含其中80%的数据结果,如果我们追求更高的区间置信度,可以选择更低的下界分位数和更高的上界分位数。

03

总结

以上就是深度学习分位数回归实现区间预测的所有内容了,一般来说,大家在深度回归模型中分别使用1/4中位数和3/4中位数作为上下界分位数损失函数,得到的上下界分位数回归结果是具有95%置信度的。

在回归任务中,我们可以轻松构建能够预测输出值范围的模型,并且不依赖对数据的先验处理,是一种非常高效的方法;另外分位数损失函数的实现,推荐使用文中的第二种方法,涉及的计算步骤更少,效率更高;分位数损失函数与MAE损失类似,是一种线性的损失函数,在loss在0附近的区间内同样存在导数不连续的问题,各位聪明的小侠客如果能有简单的方法可以规避这个缺点,则会更加好用好啦,今天的内容就到这里啦,我是oubahe,我们下次再见咯~
码字虽少,原创不易。分享是快乐的源泉,来个素质三连 —>点击左下角分享 —> 右下角点赞+在看本文,可以汇聚好运气召唤神龙哟~


推荐阅读
  • 本文将详细探讨 Java 中提供的不可变集合(如 `Collections.unmodifiableXXX`)和同步集合(如 `Collections.synchronizedXXX`)的实现原理及使用方法,帮助开发者更好地理解和应用这些工具。 ... [详细]
  • 本文介绍了如何在 C# 和 XNA 框架中实现一个自定义的 3x3 矩阵类(MMatrix33),旨在深入理解矩阵运算及其应用场景。该类参考了 AS3 Starling 和其他相关资源,以确保算法的准确性和高效性。 ... [详细]
  • ▶书中第四章部分程序,包括在加上自己补充的代码,有边权有向图的邻接矩阵,FloydWarshall算法可能含负环的有边权有向图任意两点之间的最短路径●有边权有向图的邻接矩阵1 ... [详细]
  • Java 实现二维极点算法
    本文介绍了一种使用 Java 编程语言实现的二维极点算法。该算法用于从一组二维坐标中筛选出极点,适用于需要处理几何图形和空间数据的应用场景。文章不仅详细解释了算法的工作原理,还提供了完整的代码示例。 ... [详细]
  • 本文介绍如何使用MFC和ADO技术调用SQL Server中的存储过程,以查询指定小区在特定时间段内的通话统计数据。通过用户界面选择小区ID、开始时间和结束时间,系统将计算并展示小时级的通话量、拥塞率及半速率通话比例。 ... [详细]
  • 本文探讨了如何通过预处理器开关选择不同的类实现,并解决在特定情况下遇到的链接器错误。 ... [详细]
  • This post discusses an issue encountered while using the @name annotation in documentation generation, specifically regarding nested class processing and unexpected output. ... [详细]
  • 本文探讨了如何在Classic ASP中实现与PHP的hash_hmac('SHA256', $message, pack('H*', $secret))函数等效的哈希生成方法。通过分析不同实现方式及其产生的差异,提供了一种使用Microsoft .NET Framework的解决方案。 ... [详细]
  • 本文详细介绍如何使用 Apache Spark 执行基本任务,包括启动 Spark Shell、运行示例程序以及编写简单的 WordCount 程序。同时提供了参数配置的注意事项和优化建议。 ... [详细]
  • 本文档汇总了Python编程的基础与高级面试题目,涵盖语言特性、数据结构、算法以及Web开发等多个方面,旨在帮助开发者全面掌握Python核心知识。 ... [详细]
  • 深入解析Hadoop的核心组件与工作原理
    本文详细介绍了Hadoop的三大核心组件:分布式文件系统HDFS、资源管理器YARN和分布式计算框架MapReduce。通过分析这些组件的工作机制,帮助读者更好地理解Hadoop的架构及其在大数据处理中的应用。 ... [详细]
  • 本文探讨了Hive作业中Map任务数量的确定方式,主要涉及HiveInputFormat和CombineHiveInputFormat两种InputFormat的分片计算逻辑。通过调整相关参数,可以有效控制Map任务的数量,进而优化Hive作业的性能。 ... [详细]
  • ML学习笔记20210824分类算法模型选择与调优
    3.模型选择和调优3.1交叉验证定义目的为了让模型得精度更加可信3.2超参数搜索GridSearch对K值进行选择。k[1,2,3,4,5,6]循环遍历搜索。API参数1& ... [详细]
  • 在Python编程学习过程中,许多初学者常遇到各种功能实现难题。虽然这些问题往往并不复杂,但找到高效解决方案却能显著提升编程效率。本文将介绍一个名为‘30-seconds-of-python’的优质资源,帮助大家快速掌握实用的Python技巧。 ... [详细]
  • MapReduce原理是怎么剖析的
    这期内容当中小编将会给大家带来有关MapReduce原理是怎么剖析的,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。1 ... [详细]
author-avatar
花颖年华
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有