热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Python实现-最小二乘回归树RTree

最小二乘回归树生成给定数据集D{,(xi,yi),},n维的实例xi有n个特征,通过选择一个特征j和该特征取值范围内的一个分割值s,将该组数据集分割成两部分:R

最小二乘回归树生成

给定数据集D={...,(xi,yi),...}n维的实例xi有n个特征, 通过选择一个特征j和该特征取值范围内的一个分割值s,将该组数据集分割成两部分: 

R1(j,s)={x | x[j]s}  ,  R2(j,s)={x | x[j]>s}

然后计算两个区域上所对应的输出的平均值作为该节点的输出: 
c1=average(yi|xiR1),  c2=average(yi|xiR2)

之后计算平方误差和: 
errtotal=xiR1(yic1)2+xiR2(yic1)2

j,s 的选择要使得 errtotal 最小, 本文采用的办法是用二重循环遍历所有特征和该特征下所有可能的分割点,最后找到使得 errtotal 最小的 j,s . 将数据集D作为根节点, 利用求得的 j,s 将数据集分成两个子集, 生成两个叶子节点, 并且把数据子集分配给两个叶子节点. 对叶子节点重复以上行为, 直到满足停止条件或者使得训练误差达到0, 这样就生成一颗二叉树, 当输入一个新实例之后, 根据每个节点上的 j,s 将实例点逐层划分到分到子节点, 直到遇到叶子节点, 将该叶子节点的输出值作为输出.

下面给出Python实现代码

import numpy as np
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D

#定义一个简单的树结构
class RTree:
def __init__(self,data,z,slicedIdx):
self.data =data
self.z =z
self.isLeaf = True
self.slicedIdx = slicedIdx #节点上只保存数据的序号,不保存数据子集,节约内存
self.left =None
self.right = None
self.output = np.mean(z[slicedIdx])
self.j = None
self.s = None
#本节点所带的子数据如果大于1个,则生成两个叶子节点,本节点不再是叶子节点
def grow(self):
if len(self.slicedIdx)>1:
j,s,_ = bestDivi(self.data,self.z,self.slicedIdx)
leftIdx, rightIdx = [], []
for i in self.slicedIdx:
if self.data[i,j] leftIdx.append(i)
else:
rightIdx.append(i)
self.isLeaf =False
self.left = RTree(self.data,self.z,leftIdx)
self.right = RTree(self.data,self.z,rightIdx)
self.j=j
self.s=s
def err(self):
return np.mean((self.z[self.slicedIdx]-self.output)**2)

#计算平方差
def squaErr(data,output,slicedIdx,j,s):
#挑选数据子集
region1 = []
region2 = []
for i in slicedIdx:
if data[i,j] region1.append(i)
else:
region2.append(i)
#计算子集上的平均输出
c1 = np.mean(output[region1])
err1 = np.sum((output[region1]-c1)**2)

c2 = np.mean(output[region2])
err2 = np.sum((output[region2]-c2)**2)
#返回平方差
return err1+err2

#用于选择最佳划分属性j和最切分点s
def bestDivi(data,z,slicedIdx):
min_j = 0
sortedValue = np.sort(data[slicedIdx][:,min_j])
min_s = (sortedValue[0]+sortedValue[1])/2
err = squaErr(data,z,slicedIdx,min_j,min_s)
#遍历属性
for j in range(data.shape[1]):
#产生某个属性值的分割点集合
sortedValue = np.sort(data[slicedIdx][:,j])
sliceValue = (sortedValue[1:]+sortedValue[:-1])/2
for s in sliceValue:
errNew = squaErr(data,z,slicedIdx,j,s)
if errNew err = errNew
min_j = j
min_s = s

return min_j, min_s, err

#更新树
def updateTree(tree):
if tree.isLeaf:
tree.grow()
else:
updateTree(tree.left)
updateTree(tree.right)

#预测一个数据点的输出
def predict(single_data,init_tree):
tree = init_tree
while True:
if tree.isLeaf:
return tree.output
else:
if single_data[tree.j] tree = tree.left
else:
tree = tree.right

#利用z=x+y+noise 人为生成一个数据集, 具有2个特征
n_samples = 300
points = np.random.rand(n_samples,2)
z = points[:,0]+points[:,1] + 0.2*(np.random.rand(n_samples)-0.5)
#生成根节点
root = RTree(points,z,range(n_samples))
#进行五次生长, 观测每次生长过后的拟合效果
for ii in range(5):
updateTree(root)
z_predicted = np.array([predict(p,root) for p in points])
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111,projection="3d")
ax.scatter(points[:,0],points[:,1],z)
ax.scatter(points[:,0],points[:,1],z_predicted)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

可见随着树的加深, 训练误差逐渐减小, 可见如果树足够深, 训练误差会达到0. 值得指出得是, 对于二分类问题, 树的VC维是无穷, 规模有限的数据集总可以被打散(scatter).

参考文献:

[1]李航. 统计学习方法.


推荐阅读
author-avatar
13578945682a_699
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有