热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

python程序实现rep后剪枝算法

背景在使用决策树模型时,如果训练集中的样本数很多,则会使得生成的决策树过于庞大,即分化出了很多的枝节。这时会产生过拟合问题,

背景

在使用决策树模型时,如果训练集中的样本数很多,则会使得生成的决策树过于庞大,即分化出了很多的枝节。这时会产生过拟合问题,也就是在模型在训练集上的表现效果良好,而在测试集的效果却很差。因此在生成一棵决策树之后,需要对它进行必要的剪枝,从而提高它的泛化能力。本文将讲述后剪枝算法——REP方法。


原理

剪枝是指将决策树的一些枝节去掉,将中间节点变成叶子节点,该叶子节点的预测值便是该分组训练样本yyy值的均值。剪枝算法分为预剪枝和后剪枝,预剪枝是在决策树生成的过程中同步进行,而后剪枝是在决策树生成完之后再剪枝。

REP方法也称为错误率降低剪枝,它是一类最基础、最简单的后剪枝算法,是其他剪枝算法的基础。主要过程是将训练集分为两个集合N1N_1N1N2N_2N2,可以称为训练集中的训练集和训练集中的验证集。N1N_1N1用来生成决策树,N2N_2N2用来验证剪枝前后的模型效果。具体是先用N1N_1N1来生成决策树,然后自底向上遍历所有中间节点,对于每个中间节点,比较剪枝前后的两棵决策树在验证集N2N_2N2上的效果,这个效果体现在N2N_2N2通过两棵决策树得到的预测值与原始实际值的误差平方和,若剪枝后的误差平方和更小,则对决策树进行剪枝,反之则不进行剪枝。


例子

假设通过对数据集N1N_1N1进行训练,得到了如下的决策树

现在要自底向上进行剪枝,对象是中间节点,对应到图中依次是节点5、2、3。对于节点5,先将它的左右枝8和9剪掉,得到剪枝前和剪枝后的两棵树

将数据集N2N_2N2的特征xxx分别代入这两棵树,得到两组预测值,然后通过比较两组数据的误差平方和来决策是否进行剪枝。之后再考虑节点2,最后考虑节点3。


程序实现


重新定义树结构

为了方便处理不同节点间的调用,CART回归树的树模型不再用字典进行存储,而改用自定义的类对象(参考leetcode中的树节点),每个节点可以通过成员变量来调用分裂出来的左右节点。

class TreeNode:def __init__(self, val, fea_name=None, fea_c=None):self.left = Noneself.right = Noneself.val = round(val,2)self.fea_name = fea_name self.fea_c = fea_c if fea_c is None else round(fea_c,2)

变量valvalval是指当前节点数据集的yyy值的均值,若当前节点是叶子节点,则该变量代表这种分支的预测值,若是中间节点,则可以表示为对该节点进行剪枝后的节点预测值。


生成剪枝后的子树

def sub_tree(tree, num): # 返回后序剪枝得到的子树stack = [(False, tree)]while stack:flag, t = stack.pop()if not t:continueif flag:if t.left or t.right:if num==0:t.left = Nonet.right = Nonereturn treeelse:num -= 1else:stack.append((True, t))stack.append((False, t.right))stack.append((False, t.left))return tree

采用后序遍历的方式来搜索中间节点,参数numnumnum是为了控制对应序号的中间节点,因为并不是剪去每个中间节点都能提高性能,通过numnumnum可以避开不想剪去的中间节点。


计算中间节点的个数

def mid_leaf_num(tree): if not tree or (not tree.left and not tree.right):return 0return 1 + mid_leaf_num(tree.left) + mid_leaf_num(tree.right)

效果比较函数

def ifmore(self, temp_tree, test_x, test_y):orig_ = []temp_ = []for i in range(len(test_x)):orig_.append(self.check(self.tree, test_x[i]))temp_.append(self.check(temp_tree, test_x[i]))orig_sum = sum(np.power(np.array(orig_)-test_y, 2))temp_sum = sum(np.power(np.array(temp_)-test_y, 2))if orig_sum>temp_sum: # and (orig_sum-temp_sum)/orig_sum>0.0001:self.tree = temp_treereturn Trueelse:return False

REP剪枝函数

def prune_tree(self, test_x, test_y):mid_num &#61; mid_leaf_num(self.tree)i &#61; 0while i<mid_num: temp_tree &#61; sub_tree(self.tree, i)if self.ifmore(temp_tree, test_x, test_y):i &#61; 0mid_num -&#61; 1else:i &#43;&#61; 1

实例化演示

X,y &#61; make_regression(n_samples&#61;1000, n_features&#61;4, noise&#61;0.1)
X_name &#61; np.array(list(&#39;abcd&#39;))
clf &#61; Tree_Regress()
train_x, test_x, train_y, test_y &#61; train_test_split(X, y, test_size&#61;0.30)
train_size &#61; len(train_x)//4
train_test_x, train_test_y &#61; train_x[:train_size], train_y[:train_size]
train_train_x, train_train_y &#61; train_x[train_size:], train_y[train_size:]
print(&#39;不剪枝&#39;)
clf.fit(train_x, train_y, X_name)
predict_y &#61; clf.predict(test_x)
pre_error &#61; sum(np.power(test_y-predict_y,2))
print(&#39;误差为&#xff1a;&#39;, pre_error,&#39; 节点数&#xff1a;&#39;, clf.node_num())print(&#39;有剪枝&#39;)
clf.fit(train_train_x, train_train_y, X_name)
clf.prune_tree(train_test_x, train_test_y)
predict_y &#61; clf.predict(test_x)
pre_error &#61; sum(np.power(test_y-predict_y,2))
print(&#39;误差为&#xff1a;&#39;, pre_error,&#39; 节点数&#xff1a;&#39;, clf.node_num())# 不剪枝
# 误差为&#xff1a; 9860.41223121167 节点数&#xff1a; 249
# 有剪枝
# 误差为&#xff1a; 6892.04066950184 节点数&#xff1a; 33

在对模型进行后剪枝之后&#xff0c;模型的泛化能力有所提升。


不足

REP方法虽然在一定程度上简化了决策树&#xff0c;提高了模型的性能&#xff0c;但在有些情况下反而会造成相反的结果&#xff0c;使得模型表现更差。我试了几个不同的数据集&#xff0c;发现效果其实不是很好&#xff0c;可能是这个方法考虑到的东西比较少&#xff0c;它单单考虑到了模型拟合的误差平方和&#xff0c;却没考虑生成的节点个数&#xff0c;粗暴地将影响模型性能的枝节都剪去&#xff0c;使得模型太过简单。另外&#xff0c;该方法的计算开销是很大的&#xff0c;需要遍历搜索两次中间节点。


----end----

推荐阅读
  • 使用R语言进行Foodmart数据的关联规则分析与可视化
    本文探讨了如何利用R语言中的arules和arulesViz包对Foodmart数据集进行关联规则的挖掘与可视化。文章首先介绍了数据集的基本情况,然后逐步展示了如何进行数据预处理、规则挖掘及结果的图形化呈现。 ... [详细]
  • 探索CNN的可视化技术
    神经网络的可视化在理论学习与实践应用中扮演着至关重要的角色。本文深入探讨了三种有效的CNN(卷积神经网络)可视化方法,旨在帮助读者更好地理解和优化模型。 ... [详细]
  • 本文探讨了在已知最终数组尺寸不会超过5000x10的情况下,如何利用预分配和调整大小的方法来优化Numpy数组的创建过程,以提高性能并减少内存消耗。 ... [详细]
  • 本文介绍了如何利用OpenCV库进行图像的边缘检测,并通过Canny算法提取图像中的边缘。随后,文章详细说明了如何识别图像中的特定形状(如矩形),并应用四点变换技术对目标区域进行透视校正。 ... [详细]
  • td{border:1pxsolid#808080;}参考:和FMX相关的类(表)TFmxObjectIFreeNotification ... [详细]
  • 来自FallDream的博客,未经允许,请勿转载,谢谢。一天一套noi简直了.昨天勉强做完了noi2011今天教练又丢出来一套noi ... [详细]
  • This article explores the process of integrating Promises into Ext Ajax calls for a more functional programming approach, along with detailed steps on testing these asynchronous operations. ... [详细]
  • 本文详细探讨了编程中的命名空间与作用域概念,包括其定义、类型以及在不同上下文中的应用。 ... [详细]
  • 本文详细介绍了Socket在Linux内核中的实现机制,包括基本的Socket结构、协议操作集以及不同协议下的具体实现。通过这些内容,读者可以更好地理解Socket的工作原理。 ... [详细]
  • 本文详细介绍了 Redis 中的主要数据类型,包括 String、Hash、List、Set、ZSet、Geo 和 HyperLogLog,并提供了每种类型的基本操作命令和应用场景。 ... [详细]
  • 本文介绍了多维缩放(MDS)技术,这是一种将高维数据映射到低维空间的方法,通过保持原始数据间的关系,以便于可视化和分析。文章详细描述了MDS的原理和实现过程,并提供了Python代码示例。 ... [详细]
  • AI炼金术:KNN分类器的构建与应用
    本文介绍了如何使用Python及其相关库(如NumPy、scikit-learn和matplotlib)构建KNN分类器模型。通过详细的数据准备、模型训练及新样本预测的过程,展示KNN算法的实际操作步骤。 ... [详细]
  • 深入探讨前端代码优化策略
    本文深入讨论了前端开发中代码优化的关键技术,包括JavaScript、HTML和CSS的优化方法,旨在提升网页加载速度和用户体验。 ... [详细]
  • OBS Studio自动化实践:利用脚本批量生成录制场景
    本文探讨了如何利用OBS Studio进行高效录屏,并通过脚本实现场景的自动生成。适合对自动化办公感兴趣的读者。 ... [详细]
  • 在OpenCV 3.1.0中实现SIFT与SURF特征检测
    本文介绍如何在OpenCV 3.1.0版本中通过Python 2.7环境使用SIFT和SURF算法进行图像特征点检测。由于这些高级功能在OpenCV 3.0.0及更高版本中被移至额外的contrib模块,因此需要特别处理才能正常使用。 ... [详细]
author-avatar
ociVyouzhangzh063_1fd2bf_633
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有