热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

决策树在鸢尾花数据集上对不同特征组合的分类效果分析及模型性能比较

本文探讨了决策树算法在鸢尾花数据集上的应用,分析了不同特征组合对分类效果的影响,并对模型性能进行了详细比较。决策树作为一种层次化的分类方法,通过递归地划分特征空间,形成树状结构,每个节点代表一个特征判断,最终达到分类目的。研究结果表明,不同特征组合对模型性能有显著影响,为实际应用提供了重要参考。

一、什么是决策树

决策树算法,人如其名,结构就像一棵树,有分叉的枝丫和树叶。枝丫的分叉处是关于目标某一个特征的判断,枝丫本体则是关于该特征的判断结果,而叶子则是判断过后产生的决策结果。

d473e70766f9cc298c440dc71563be81.png

上图就是一个最为简单的分类树决策,当我们看天气预报时,根据降雨、雾霾、气温、活动范围是室内活动还是室外活动等等特征将自己的行为分类为出门和不出门。简单来说,决策树可以被看做由一大堆if-then的判断,每一条枝丫都是一条规则。

决策树算法的核心解决两个问题:

  • 如何从数据表中找出最佳节点和最佳分枝?
    决策树是对特征提问,如何找出最佳节点和最佳分支,我怎么知道哪些特征提问,才能生成有效的树呢?
  • 如何让决策树停止增长,防止过拟合?
    如果我有无数个特征,决策树会长成什么样子,他是不是会长成无数层深,我们要怎么样让它停止成长呢?怎么样防止过拟合呢?

1.1 决策树的优点

  • 计算简单,易于理解,可解释性强;
  • 比较适合处理有缺失属性的样本;
  • 能够处理不相关的特征;
  • 在相对短的时间内能够对大型数据源做出可行且效果良好的结果。

1.2 决策树的缺点

  • 容易发生过拟合(随机森林可以很大程度上减少过拟合);
  • 忽略了数据之间的相关性;
  • 对于那些各类别样本数量不一致的数据,在决策树当中,信息增益的结果偏向于那些具有更多数值的特征(只要是>- 使用了信息增益,都有这个缺点,如RF)。

决策树可以用来分类,也可以用来回归。

1.3 决策树参数

sklearn决策树的两个类:

tree.DecisionTreeClassifier()
tree.DecisionTreeRegressor()

决策树的重要参数Criterion:

Criterion这个参数正是用来决定不纯度的计算方法的。sklearn提供了两种选择:
  • “entropy”,使用信息熵(Entropy)
  • “gini”,使用基尼系数(Gini Impurity)

通常就使用基尼系数
数据维度很大,噪音很大时使用基尼系数
维度低,数据比较清晰的时候,信息熵和基尼系数没区别

决策树的重要参数random_state:

random_state用来设置分枝中的随机模式的参数,默认None,在高维度时随机性会表现更明显,低维度的数据(比如鸢尾花数据集),随机性几乎不会显现。输入任意整数,会一直长出同一棵树,让模型稳定下来。

决策树的重要参数splitter:

splitter也是用来控制决策树中的随机选项的,有两种输入值,输入”best”,决策树在分枝时虽然随机,但是还是会优先选择更重要的特征进行分枝(重要性可以通过属性feature_importances_查看),输入“random”,决策树在分枝时会更加随机,树会因为含有更多的不必要信息而更深更大,并因这些不必要信息而降低对训练集的拟合。这也是防止过拟合的一种方式。当你预测到你的模型会过拟合,用这两个参数来帮助你降低树建成之后过拟合的可能性。当然,树一旦建成,我们依然是使用剪枝参数来防止过拟合。

剪枝参数

在不加限制的情况下,一棵决策树会生长到衡量不纯度的指标最优,或者没有更多的特征可用为止。这样的决策树往往会过拟合,过拟合这就是说,它会在训练集上表现很好,在测试集上却表现糟糕。我们收集的样本数据不可能和整体的状况完全一致,因此当一棵决策树对训练数据有了过于优秀的解释性,它找出的规则必然包含了训练样本中的噪声,并使它对未知数据的拟合程度不足。

剪枝策略对决策树的影响巨大,正确的剪枝策略是优化决策树算法的核心。sklearn为我们提供了不同的剪枝策略:

1. max_depth
限制树的最大深度,超过设定深度的树枝全部剪掉
这是用得最广泛的剪枝参数,在高维度低样本量时非常有效。决策树多生长一层,对样本量的需求会增加一倍,所以限制树深度能够有效地限制过拟合。在集成算法中也非常实用。实际使用时,建议从=3开始尝试,看看拟合的效果再决定是否增加设定深度。
2. min_samples_leaf
min_samples_leaf 限定,一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生
一般搭配max_depth使用,在回归树中有神奇的效果,可以让模型变得更加平滑。这个参数的数量设置得太小会引起过拟合,设置得太大就会阻止模型学习数据。一般来说,建议从=5开始使用。如果叶节点中含有的样本量变化很大,建议输入浮点数作为样本量的百分比来使用。同时,这个参数可以保证每个叶子的最小尺寸,可以在回归问题中避免低方差,过拟合的叶子节点出现。对于类别不多的分类问题,=1通常就是最佳选择。
3. min_samples_split
min_samples_split限定,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生。
如果一个样本包20个样本,我们在不限定的情况下会不断分下去的,如果设定min_samples_split=15,那么这个节点就不会分了。

二、决策树分类

2.1 准备数据

iris = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',header=None)
iris.columns=['SepalLengthCm','SepalWidthCm','PetalLengthCm','PetalWidthCm','Species']le = LabelEncoder()
le.fit(iris['Species'])
features = ['SepalWidthCm','PetalWidthCm']
X = iris[features]
y = le.transform(iris['Species'])

2.2 进行分类

tr = tree.DecisionTreeClassifier()
tr.fit(X,y)
score = numpy.mean(cross_val_score(tr,X,y,cv=5,scoring='accuracy'))
print('决策树分类模型平均性能得分:'+str(score))

输出的结果:决策树分类模型平均性能得分:0.933333333333

我们修改一下决策树的参数在进行一次模型性能评估。

tr = tree.DecisionTreeClassifier(criterion="entropy",random_state=10,splitter="best",max_depth=10,min_samples_leaf=5,min_samples_split=5)
score = numpy.mean(cross_val_score(tr,X,y,cv=5,scoring='accuracy'))
print('决策树分类模型平均性能得分:'+str(score))

输出结果:决策树分类模型平均性能得分:0.94,稍微好一点。参数的其他选项请自行测试。

2.3 和logistic分类对比

lm = linear_model.LogisticRegression()
score = numpy.mean(cross_val_score(lm,X,y,cv=5,scoring='accuracy'))
print('logistic回归模型平均性能得分:'+str(score))

输出的结果:logistic回归模型平均性能得分:0.94
可见,在此数据集中,logistic模型和决策树模型的准确率差不多。

如何生成决策树可视化,请参考模块graphviz

三、决策树回归

我们上面已经准备好数据了,我们只需要构造一下我们的因变量y,让它y = iris[‘PetalWidthCm’]

3.1 进行回归

y = iris['PetalWidthCm']
tr = tree.DecisionTreeRegressor()
score = numpy.mean(-cross_val_score(tr,X,y,cv=5,scoring='neg_mean_squared_error'))
print('平均性能得分:'+str(score))

输出结果为:决策树回归模型平均性能得分:0.0008

3.2 和线性回归对比

lm = linear_model.LinearRegression()
score = numpy.mean(-cross_val_score(tr,X,y,cv=5,scoring='neg_mean_squared_error'))
print('线性回归模型平均性能得分:'+str(score))

输出结果为:决策树回归模型平均性能得分:0.0042

可以看到决策树回归模型性能更好。未来我会对如何选择最优的特征值,如何选择最优的模型,如何选最优的模型参数进行详细深入的分享。

全部代码

import pandas as pd
from sklearn.model_selection import cross_val_score
import numpy
from sklearn.preprocessing import LabelEncoder
from sklearn import linear_model
from sklearn import tree
from sklearn import ensembleiris = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',header=None)
iris.columns=['SepalLengthCm','SepalWidthCm','PetalLengthCm','PetalWidthCm','Species']le = LabelEncoder()
le.fit(iris['Species'])
features = ['SepalWidthCm','PetalWidthCm']
X = iris[features]
y = le.transform(iris['Species'])tr = tree.DecisionTreeClassifier(criterion="entropy",random_state=10,splitter="best",max_depth=10,min_samples_leaf=5,min_samples_split=5)
score = numpy.mean(cross_val_score(tr,X,y,cv=5,scoring='accuracy'))
print('决策树分类模型平均性能得分:'+str(score))lm = linear_model.LogisticRegression()
score = numpy.mean(cross_val_score(lm,X,y,cv=5,scoring='accuracy'))
print('logistic回归模型平均性能得分:'+str(score))y = iris['PetalWidthCm']
tr = tree.DecisionTreeRegressor()
score = numpy.mean(-cross_val_score(tr,X,y,cv=5,scoring='neg_mean_squared_error'))
print('决策树回归模型平均性能得分:'+str(score))lm = linear_model.LinearRegression()
score = numpy.mean(-cross_val_score(tr,X,y,cv=5,scoring='neg_mean_squared_error'))
print('线性回归模型平均性能得分:'+str(score))




推荐阅读
  • 采用IKE方式建立IPsec安全隧道
    一、【组网和实验环境】按如上的接口ip先作配置,再作ipsec的相关配置,配置文本见文章最后本文实验采用的交换机是H3C模拟器,下载地址如 ... [详细]
  • Coursera ML 机器学习
    2019独角兽企业重金招聘Python工程师标准线性回归算法计算过程CostFunction梯度下降算法多变量回归![选择特征](https:static.oschina.n ... [详细]
  • 中科院学位论文排版指南
    随着毕业季的到来,许多即将毕业的学生开始撰写学位论文。本文介绍了使用LaTeX排版学位论文的方法,特别是针对中国科学院大学研究生学位论文撰写规范指导意见的最新要求。LaTeX以其精确的控制和美观的排版效果成为许多学者的首选。 ... [详细]
  • 本题探讨了在大数据结构背景下,如何通过整体二分和CDQ分治等高级算法优化处理复杂的时间序列问题。题目设定包括节点数量、查询次数和权重限制,并详细分析了解决方案中的关键步骤。 ... [详细]
  • 目录一、salt-job管理#job存放数据目录#缓存时间设置#Others二、returns模块配置job数据入库#配置returns返回值信息#mysql安全设置#创建模块相关 ... [详细]
  • 深入理解Java字符串池机制
    本文详细解析了Java中的字符串池(String Pool)机制,探讨其工作原理、实现方式及其对性能的影响。通过具体的代码示例和分析,帮助读者更好地理解和应用这一重要特性。 ... [详细]
  • 本文介绍如何使用 Angular 6 的 HttpClient 模块来获取 HTTP 响应头,包括代码示例和常见问题的解决方案。 ... [详细]
  • 本文介绍了在Java环境中使用PDFBox和XPDF工具从PDF文件中提取文本内容的方法。重点讨论了处理中文字符集及解决相关错误的技术细节,特别是针对某些特定格式的PDF文件(如网上填写的报名表和下载的论文)遇到的问题及解决方案。 ... [详细]
  • 本文介绍如何使用MFC和ADO技术调用SQL Server中的存储过程,以查询指定小区在特定时间段内的通话统计数据。通过用户界面选择小区ID、开始时间和结束时间,系统将计算并展示小时级的通话量、拥塞率及半速率通话比例。 ... [详细]
  • 实用正则表达式有哪些
    小编给大家分享一下实用正则表达式有哪些,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下 ... [详细]
  • 主板IO用W83627THG,用VC如何取得CPU温度,系统温度,CPU风扇转速,VBat的电压. ... [详细]
  • 本文探讨了如何通过预处理器开关选择不同的类实现,并解决在特定情况下遇到的链接器错误。 ... [详细]
  • 本文提供了 CIW Dreamweaver MX2004 认证考试的详细试题解析,涵盖不同难度级别的选择题、多项选择题和判断题。通过这些题目,考生可以更好地理解考试内容并为实际考试做好准备。 ... [详细]
  • 深入解析SpringMVC核心组件:DispatcherServlet的工作原理
    本文详细探讨了SpringMVC的核心组件——DispatcherServlet的运作机制,旨在帮助有一定Java和Spring基础的开发人员理解HTTP请求是如何被映射到Controller并执行的。文章将解答以下问题:1. HTTP请求如何映射到Controller;2. Controller是如何被执行的。 ... [详细]
  • 通常情况下,修改my.cnf配置文件后需要重启MySQL服务才能使新参数生效。然而,通过特定命令可以在不重启服务的情况下实现配置的即时更新。本文将详细介绍如何在线调整MySQL配置,并验证其有效性。 ... [详细]
author-avatar
三封酒可_894
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有