热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

实用:sklearn提取决策树规则代码(附python代码)

《老饼讲解机器学习》http:ml.bbbdata.comteach#107目录一.问题二.主要思路三.代码实例1.数据提取2.预测函数3.准确性测试一.问题在决策树

《老饼讲解机器学习》http://ml.bbbdata.com/teach#107


目录

一.问题

二.主要思路

三.代码实例

1.数据提取

2.预测函数

3.准确性测试



一.问题

在决策树模型建好之后,要提取规则布署到生产。

二.主要思路

只提取数据,在生产环境写出通用预测代码。新的模型只需替换数据即可。

备注:一般不弄成一系列的if else,写死代码不便于更换模型。

三.代码实例

1.数据提取

使用如下get_tree函数,将树数据提取成字典:

from sklearn import tree
import numpy as np
def get_tree(sk_tree):#--------------拷贝sklearn树模型关键信息--------------------children_left = sk_tree.tree_.children_left.copy() # 左节点编号children_right = sk_tree.tree_.children_right.copy() # 右节点编号feature = sk_tree.tree_.feature.copy() # 分割的变量threshold = sk_tree.tree_.threshold.copy() # 分割阈值impurity = sk_tree.tree_.impurity.copy() # 不纯度(gini)n_node_samples = sk_tree.tree_.n_node_samples.copy() # 样本个数value = sk_tree.tree_.value.copy() # 样本分布n_sample = value[0].sum() # 总样本个数node_num = len(children_left) # 节点个数depth = sk_tree.get_depth()# ------------补充节点父节点信息---------------------------parent = np.zeros(node_num).astype(int)parent[0] = -1branch_idx = np.where(children_left!=-1)[0]for i in branch_idx:parent[children_left[i]] = i parent[children_right[i]]= i #-------------存成字典----------------------------------------- tree = {'children_left':children_left,'children_right':children_right,'feature':feature,'threshold':threshold,'impurity':impurity,'n_node_samples':n_node_samples,'value':value,'depth':depth,'n_sample':n_sample,'node_num':node_num,'parent':parent}return tree

将训练好的模型sk_tree传入以上函数,转化成字典,保存成文件。

2.预测函数

在生产时使用如下tree_predict 函数预测(其它语言类似以下逻辑)。

import numpy as np
def tree_predict(tree,x):node_idx = 0t = 0while(t

3.准确性测试

from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
from get_tree import get_tree
from tree_pred import tree_predict#----------------数据准备----------------------------
iris = load_iris() # 加载数据
X = iris.data
y = iris.target
#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier() # sk-learn的决策树模型
clf = clf.fit(X, y) # 用数据训练树模型构建()
#--------------将树提取成简单的字典--------------------------------
tree = get_tree(clf)
#-------------------------
#将tree持久化到服务器,服务器中用tree_predict进行预测即可
#-------------------------#------------测试函数的准确性-----------------------------
self_pred_y = np.zeros(len(y))
self_pred_prob = np.zeros((len(y),len(tree['value'][0][0])))
for i in range(X.shape[0]):pred_class,pred_prob = tree_predict(tree,X[i])self_pred_y[i] = pred_classself_pred_prob[i] = pred_prob
pred_y = clf.predict(X)
pred_prob = clf.predict_proba(X)
print("与sklearn预测结果差异个数(类别):",np.sum(pred_y != self_pred_y))
print("与sklearn预测结果差异个数(概率):",np.sum(pred_prob != self_pred_prob))

 测试结果:

与sklearn预测结果差异个数(类别): 0
与sklearn预测结果差异个数(概率): 0

相关文章

《深入浅出:决策树入门简介》

《一个简单的决策树分类例子》

《sklearn决策树结果可视化》

《sklearn决策树参数详解》


推荐阅读
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • 本文详细介绍了Java编程语言中的核心概念和常见面试问题,包括集合类、数据结构、线程处理、Java虚拟机(JVM)、HTTP协议以及Git操作等方面的内容。通过深入分析每个主题,帮助读者更好地理解Java的关键特性和最佳实践。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 深入解析 Apache Shiro 安全框架架构
    本文详细介绍了 Apache Shiro,一个强大且灵活的开源安全框架。Shiro 专注于简化身份验证、授权、会话管理和加密等复杂的安全操作,使开发者能够更轻松地保护应用程序。其核心目标是提供易于使用和理解的API,同时确保高度的安全性和灵活性。 ... [详细]
  • 深入解析 Spring Security 用户认证机制
    本文将详细介绍 Spring Security 中用户登录认证的核心流程,重点分析 AbstractAuthenticationProcessingFilter 和 AuthenticationManager 的工作原理。通过理解这些组件的实现,读者可以更好地掌握 Spring Security 的认证机制。 ... [详细]
  • PHP 过滤器详解
    本文深入探讨了 PHP 中的过滤器机制,包括常见的 $_SERVER 变量、filter_has_var() 函数、filter_id() 函数、filter_input() 函数及其数组形式、filter_list() 函数以及 filter_var() 和其数组形式。同时,详细介绍了各种过滤器的用途和用法。 ... [详细]
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • 深入解析:手把手教你构建决策树算法
    本文详细介绍了机器学习中广泛应用的决策树算法,通过天气数据集的实例演示了ID3和CART算法的手动推导过程。文章长度约2000字,建议阅读时间5分钟。 ... [详细]
  • 本文详细介绍了Java中org.w3c.dom.Text类的splitText()方法,通过多个代码示例展示了其实际应用。该方法用于将文本节点在指定位置拆分为两个节点,并保持在文档树中。 ... [详细]
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • 题目Link题目学习link1题目学习link2题目学习link3%%%受益匪浅!-----&# ... [详细]
author-avatar
丢失的面包树
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有