热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Python实现-最小二乘回归树RTree

最小二乘回归树生成给定数据集D{,(xi,yi),},n维的实例xi有n个特征,通过选择一个特征j和该特征取值范围内的一个分割值s,将该组数据集分割成两部分:R

最小二乘回归树生成

给定数据集D={...,(xi,yi),...}n维的实例xi有n个特征, 通过选择一个特征j和该特征取值范围内的一个分割值s,将该组数据集分割成两部分: 

R1(j,s)={x | x[j]s}  ,  R2(j,s)={x | x[j]>s}

然后计算两个区域上所对应的输出的平均值作为该节点的输出: 
c1=average(yi|xiR1),  c2=average(yi|xiR2)

之后计算平方误差和: 
errtotal=xiR1(yic1)2+xiR2(yic1)2

j,s 的选择要使得 errtotal 最小, 本文采用的办法是用二重循环遍历所有特征和该特征下所有可能的分割点,最后找到使得 errtotal 最小的 j,s . 将数据集D作为根节点, 利用求得的 j,s 将数据集分成两个子集, 生成两个叶子节点, 并且把数据子集分配给两个叶子节点. 对叶子节点重复以上行为, 直到满足停止条件或者使得训练误差达到0, 这样就生成一颗二叉树, 当输入一个新实例之后, 根据每个节点上的 j,s 将实例点逐层划分到分到子节点, 直到遇到叶子节点, 将该叶子节点的输出值作为输出.

下面给出Python实现代码

import numpy as np
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D

#定义一个简单的树结构
class RTree:
def __init__(self,data,z,slicedIdx):
self.data =data
self.z =z
self.isLeaf = True
self.slicedIdx = slicedIdx #节点上只保存数据的序号,不保存数据子集,节约内存
self.left =None
self.right = None
self.output = np.mean(z[slicedIdx])
self.j = None
self.s = None
#本节点所带的子数据如果大于1个,则生成两个叶子节点,本节点不再是叶子节点
def grow(self):
if len(self.slicedIdx)>1:
j,s,_ = bestDivi(self.data,self.z,self.slicedIdx)
leftIdx, rightIdx = [], []
for i in self.slicedIdx:
if self.data[i,j] leftIdx.append(i)
else:
rightIdx.append(i)
self.isLeaf =False
self.left = RTree(self.data,self.z,leftIdx)
self.right = RTree(self.data,self.z,rightIdx)
self.j=j
self.s=s
def err(self):
return np.mean((self.z[self.slicedIdx]-self.output)**2)

#计算平方差
def squaErr(data,output,slicedIdx,j,s):
#挑选数据子集
region1 = []
region2 = []
for i in slicedIdx:
if data[i,j] region1.append(i)
else:
region2.append(i)
#计算子集上的平均输出
c1 = np.mean(output[region1])
err1 = np.sum((output[region1]-c1)**2)

c2 = np.mean(output[region2])
err2 = np.sum((output[region2]-c2)**2)
#返回平方差
return err1+err2

#用于选择最佳划分属性j和最切分点s
def bestDivi(data,z,slicedIdx):
min_j = 0
sortedValue = np.sort(data[slicedIdx][:,min_j])
min_s = (sortedValue[0]+sortedValue[1])/2
err = squaErr(data,z,slicedIdx,min_j,min_s)
#遍历属性
for j in range(data.shape[1]):
#产生某个属性值的分割点集合
sortedValue = np.sort(data[slicedIdx][:,j])
sliceValue = (sortedValue[1:]+sortedValue[:-1])/2
for s in sliceValue:
errNew = squaErr(data,z,slicedIdx,j,s)
if errNew err = errNew
min_j = j
min_s = s

return min_j, min_s, err

#更新树
def updateTree(tree):
if tree.isLeaf:
tree.grow()
else:
updateTree(tree.left)
updateTree(tree.right)

#预测一个数据点的输出
def predict(single_data,init_tree):
tree = init_tree
while True:
if tree.isLeaf:
return tree.output
else:
if single_data[tree.j] tree = tree.left
else:
tree = tree.right

#利用z=x+y+noise 人为生成一个数据集, 具有2个特征
n_samples = 300
points = np.random.rand(n_samples,2)
z = points[:,0]+points[:,1] + 0.2*(np.random.rand(n_samples)-0.5)
#生成根节点
root = RTree(points,z,range(n_samples))
#进行五次生长, 观测每次生长过后的拟合效果
for ii in range(5):
updateTree(root)
z_predicted = np.array([predict(p,root) for p in points])
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111,projection="3d")
ax.scatter(points[:,0],points[:,1],z)
ax.scatter(points[:,0],points[:,1],z_predicted)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

可见随着树的加深, 训练误差逐渐减小, 可见如果树足够深, 训练误差会达到0. 值得指出得是, 对于二分类问题, 树的VC维是无穷, 规模有限的数据集总可以被打散(scatter).

参考文献:

[1]李航. 统计学习方法.


推荐阅读
  • 本文详细介绍了如何在 Windows 环境下使用 node-gyp 工具进行 Node.js 本地扩展的编译和配置,涵盖从环境搭建到代码实现的全过程。 ... [详细]
  • 尽管深度学习带来了广泛的应用前景,其训练通常需要强大的计算资源。然而,并非所有开发者都能负担得起高性能服务器或专用硬件。本文探讨了如何在有限的硬件条件下(如ARM CPU)高效运行深度神经网络,特别是通过选择合适的工具和框架来加速模型推理。 ... [详细]
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • 目录一、salt-job管理#job存放数据目录#缓存时间设置#Others二、returns模块配置job数据入库#配置returns返回值信息#mysql安全设置#创建模块相关 ... [详细]
  • 深入解析JMeter中的JSON提取器及其应用
    本文详细介绍了如何在JMeter中使用JSON提取器来获取和处理API响应中的数据。特别是在需要将一个接口返回的数据作为下一个接口的输入时,JSON提取器是一个非常有用的工具。 ... [详细]
  • 在 Flutter 开发过程中,开发者经常会遇到 Widget 构造函数中的可选参数 Key。对于初学者来说,理解 Key 的作用和使用场景可能是一个挑战。本文将详细探讨 Key 的概念及其应用场景,并通过实例帮助你更好地掌握这一重要工具。 ... [详细]
  • 本文将详细探讨Linux pinctrl子系统的各个关键数据结构,帮助读者深入了解其内部机制。通过分析这些数据结构及其相互关系,我们将进一步理解pinctrl子系统的工作原理和设计思路。 ... [详细]
  • Git管理工具SourceTree安装与使用指南
    本文详细介绍了Git管理工具SourceTree的安装、配置及团队协作方案,旨在帮助开发者更高效地进行版本控制和项目管理。 ... [详细]
  • 本文详细介绍了 Java 中的 org.apache.hadoop.registry.client.impl.zk.ZKPathDumper 类,提供了丰富的代码示例和使用指南。通过这些示例,读者可以更好地理解如何在实际项目中利用 ZKPathDumper 类进行注册表树的转储操作。 ... [详细]
  • Kubernetes 持久化存储与数据卷详解
    本文深入探讨 Kubernetes 中持久化存储的使用场景、PV/PVC/StorageClass 的基本操作及其实现原理,旨在帮助读者理解如何高效管理容器化应用的数据持久化需求。 ... [详细]
  • 在创建新的Android项目时,您可能会遇到aapt错误,提示无法打开libstdc++.so.6共享对象文件。本文将探讨该问题的原因及解决方案。 ... [详细]
  • 通过生动有趣的顺口溜,帮助孩子们轻松记住各种运动、服饰、自然景物、星期和食物的英文单词。这些口诀不仅朗朗上口,还能加深对单词的记忆。 ... [详细]
  • 本题探讨了在大数据结构背景下,如何通过整体二分和CDQ分治等高级算法优化处理复杂的时间序列问题。题目设定包括节点数量、查询次数和权重限制,并详细分析了解决方案中的关键步骤。 ... [详细]
  • 2018-2019学年第六周《Java数据结构与算法》学习总结
    本文总结了2018-2019学年第六周在《Java数据结构与算法》课程中的学习内容,重点介绍了非线性数据结构——树的相关知识及其应用。 ... [详细]
  • Nginx 反向代理与负载均衡实验
    本实验旨在通过配置 Nginx 实现反向代理和负载均衡,确保从北京本地代理服务器访问上海的 Web 服务器时,能够依次显示红、黄、绿三种颜色页面以验证负载均衡效果。 ... [详细]
author-avatar
13578945682a_699
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有