热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Python实现-最小二乘回归树RTree

最小二乘回归树生成给定数据集D{,(xi,yi),},n维的实例xi有n个特征,通过选择一个特征j和该特征取值范围内的一个分割值s,将该组数据集分割成两部分:R

最小二乘回归树生成

给定数据集D={...,(xi,yi),...}n维的实例xi有n个特征, 通过选择一个特征j和该特征取值范围内的一个分割值s,将该组数据集分割成两部分: 

R1(j,s)={x | x[j]s}  ,  R2(j,s)={x | x[j]>s}

然后计算两个区域上所对应的输出的平均值作为该节点的输出: 
c1=average(yi|xiR1),  c2=average(yi|xiR2)

之后计算平方误差和: 
errtotal=xiR1(yic1)2+xiR2(yic1)2

j,s 的选择要使得 errtotal 最小, 本文采用的办法是用二重循环遍历所有特征和该特征下所有可能的分割点,最后找到使得 errtotal 最小的 j,s . 将数据集D作为根节点, 利用求得的 j,s 将数据集分成两个子集, 生成两个叶子节点, 并且把数据子集分配给两个叶子节点. 对叶子节点重复以上行为, 直到满足停止条件或者使得训练误差达到0, 这样就生成一颗二叉树, 当输入一个新实例之后, 根据每个节点上的 j,s 将实例点逐层划分到分到子节点, 直到遇到叶子节点, 将该叶子节点的输出值作为输出.

下面给出Python实现代码

import numpy as np
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D

#定义一个简单的树结构
class RTree:
def __init__(self,data,z,slicedIdx):
self.data =data
self.z =z
self.isLeaf = True
self.slicedIdx = slicedIdx #节点上只保存数据的序号,不保存数据子集,节约内存
self.left =None
self.right = None
self.output = np.mean(z[slicedIdx])
self.j = None
self.s = None
#本节点所带的子数据如果大于1个,则生成两个叶子节点,本节点不再是叶子节点
def grow(self):
if len(self.slicedIdx)>1:
j,s,_ = bestDivi(self.data,self.z,self.slicedIdx)
leftIdx, rightIdx = [], []
for i in self.slicedIdx:
if self.data[i,j] leftIdx.append(i)
else:
rightIdx.append(i)
self.isLeaf =False
self.left = RTree(self.data,self.z,leftIdx)
self.right = RTree(self.data,self.z,rightIdx)
self.j=j
self.s=s
def err(self):
return np.mean((self.z[self.slicedIdx]-self.output)**2)

#计算平方差
def squaErr(data,output,slicedIdx,j,s):
#挑选数据子集
region1 = []
region2 = []
for i in slicedIdx:
if data[i,j] region1.append(i)
else:
region2.append(i)
#计算子集上的平均输出
c1 = np.mean(output[region1])
err1 = np.sum((output[region1]-c1)**2)

c2 = np.mean(output[region2])
err2 = np.sum((output[region2]-c2)**2)
#返回平方差
return err1+err2

#用于选择最佳划分属性j和最切分点s
def bestDivi(data,z,slicedIdx):
min_j = 0
sortedValue = np.sort(data[slicedIdx][:,min_j])
min_s = (sortedValue[0]+sortedValue[1])/2
err = squaErr(data,z,slicedIdx,min_j,min_s)
#遍历属性
for j in range(data.shape[1]):
#产生某个属性值的分割点集合
sortedValue = np.sort(data[slicedIdx][:,j])
sliceValue = (sortedValue[1:]+sortedValue[:-1])/2
for s in sliceValue:
errNew = squaErr(data,z,slicedIdx,j,s)
if errNew err = errNew
min_j = j
min_s = s

return min_j, min_s, err

#更新树
def updateTree(tree):
if tree.isLeaf:
tree.grow()
else:
updateTree(tree.left)
updateTree(tree.right)

#预测一个数据点的输出
def predict(single_data,init_tree):
tree = init_tree
while True:
if tree.isLeaf:
return tree.output
else:
if single_data[tree.j] tree = tree.left
else:
tree = tree.right

#利用z=x+y+noise 人为生成一个数据集, 具有2个特征
n_samples = 300
points = np.random.rand(n_samples,2)
z = points[:,0]+points[:,1] + 0.2*(np.random.rand(n_samples)-0.5)
#生成根节点
root = RTree(points,z,range(n_samples))
#进行五次生长, 观测每次生长过后的拟合效果
for ii in range(5):
updateTree(root)
z_predicted = np.array([predict(p,root) for p in points])
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111,projection="3d")
ax.scatter(points[:,0],points[:,1],z)
ax.scatter(points[:,0],points[:,1],z_predicted)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

可见随着树的加深, 训练误差逐渐减小, 可见如果树足够深, 训练误差会达到0. 值得指出得是, 对于二分类问题, 树的VC维是无穷, 规模有限的数据集总可以被打散(scatter).

参考文献:

[1]李航. 统计学习方法.


推荐阅读
  • 在OpenCV 3.1.0中实现SIFT与SURF特征检测
    本文介绍如何在OpenCV 3.1.0版本中通过Python 2.7环境使用SIFT和SURF算法进行图像特征点检测。由于这些高级功能在OpenCV 3.0.0及更高版本中被移至额外的contrib模块,因此需要特别处理才能正常使用。 ... [详细]
  • 深入解析层次聚类算法
    本文详细介绍了层次聚类算法的基本原理,包括其通过构建层次结构来分类样本的特点,以及自底向上(凝聚)和自顶向下(分裂)两种主要的聚类策略。文章还探讨了不同距离度量方法对聚类效果的影响,并提供了具体的参数设置指导。 ... [详细]
  • 在执行市场篮子分析时遇到性能瓶颈,尤其是在设定频繁项集的支持度阈值为1%时。本文探讨了如何通过调整代码和参数来提高分析效率。 ... [详细]
  • 本教程介绍如何在C#中通过递归方法将具有父子关系的列表转换为树形结构。我们将详细探讨如何处理字符串类型的键值,并提供一个实用的示例。 ... [详细]
  • 线段树详解与实现
    本文详细介绍了线段树的基本概念及其在编程竞赛中的应用,并提供了一个具体的线段树实现代码示例。 ... [详细]
  • 本文详细介绍如何在华为鲲鹏平台上构建和使用适配ARM架构的Redis Docker镜像,解决常见错误并提供优化建议。 ... [详细]
  • 编译原理中的语法分析方法探讨
    本文探讨了在编译原理课程中遇到的复杂文法问题,特别是当使用SLR(1)文法时遇到的多重规约与移进冲突。文章讨论了可能的解决策略,包括递归下降解析、运算符优先级解析等,并提供了相关示例。 ... [详细]
  • 题目编号:2049 [SDOI2008]Cave Exploration。题目描述了一种动态图操作场景,涉及三种基本操作:断开两个节点间的连接(destroy(a,b))、建立两个节点间的连接(connect(a,b))以及查询两节点是否连通(query(a,b))。所有操作均确保图中无环存在。 ... [详细]
  • 本文详细介绍了HashSet类,它是Set接口的一个实现,底层使用哈希表(实际上是HashMap实例)。HashSet不保证元素的迭代顺序,并且是非线程安全的。 ... [详细]
  • android开发分享荐                                                         Android思维导图布局:效果展示及使用方法
    思维导图布局的前身是树形布局,对树形布局基本使用还不太了解的朋友可以先看看我写的树形布局系列教程,了解了树形布局的使用方法后再来阅读本文章。先睹为快来看看效果吧,横向效果如下:纵向 ... [详细]
  • 本文介绍了如何在Linux系统中获取库源码,并在从源代码编译软件时收集所需的依赖项列表。 ... [详细]
  • 本文通过一个简单的示例,展示如何使用ASP技术生成HTML文件。示例包括两个页面:首页index.htm和处理页面send.asp。 ... [详细]
  • 本文介绍了如何使用开源工具ChkBugReport来解析和分析Android设备的Bugreport。ChkBugReport能够将复杂的Bugreport转换为易于阅读的HTML报告,并提供详细的图表和分析结论。 ... [详细]
  • 开发笔记:树的浅析与实现 ... [详细]
  • 理解GiST索引的空间构造原理
    通过空间思维解析GiST索引的构建方式及其在空间数据检索中的应用。 ... [详细]
author-avatar
13578945682a_699
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有