《老饼讲解机器学习》
http://ml.bbbdata.com/teach#107
目录
一.问题
二.主要思路
三.代码实例
1.数据提取
2.预测函数
3.准确性测试
一.问题
在决策树模型建好之后,要提取规则布署到生产。
二.主要思路
只提取数据,在生产环境写出通用预测代码。新的模型只需替换数据即可。
备注:一般不弄成一系列的if else,写死代码不便于更换模型。
三.代码实例
1.数据提取
使用如下get_tree函数,将树数据提取成字典:
from sklearn import tree
import numpy as np
def get_tree(sk_tree):#--------------拷贝sklearn树模型关键信息--------------------children_left = sk_tree.tree_.children_left.copy() # 左节点编号children_right = sk_tree.tree_.children_right.copy() # 右节点编号feature = sk_tree.tree_.feature.copy() # 分割的变量threshold = sk_tree.tree_.threshold.copy() # 分割阈值impurity = sk_tree.tree_.impurity.copy() # 不纯度(gini)n_node_samples = sk_tree.tree_.n_node_samples.copy() # 样本个数value = sk_tree.tree_.value.copy() # 样本分布n_sample = value[0].sum() # 总样本个数node_num = len(children_left) # 节点个数depth = sk_tree.get_depth()# ------------补充节点父节点信息---------------------------parent = np.zeros(node_num).astype(int)parent[0] = -1branch_idx = np.where(children_left!=-1)[0]for i in branch_idx:parent[children_left[i]] = i parent[children_right[i]]= i #-------------存成字典----------------------------------------- tree = {'children_left':children_left,'children_right':children_right,'feature':feature,'threshold':threshold,'impurity':impurity,'n_node_samples':n_node_samples,'value':value,'depth':depth,'n_sample':n_sample,'node_num':node_num,'parent':parent}return tree
将训练好的模型sk_tree传入以上函数,转化成字典,保存成文件。
2.预测函数
在生产时使用如下tree_predict 函数预测(其它语言类似以下逻辑)。
import numpy as np
def tree_predict(tree,x):node_idx = 0t = 0while(t
3.准确性测试
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
from get_tree import get_tree
from tree_pred import tree_predict#----------------数据准备----------------------------
iris = load_iris() # 加载数据
X = iris.data
y = iris.target
#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier() # sk-learn的决策树模型
clf = clf.fit(X, y) # 用数据训练树模型构建()
#--------------将树提取成简单的字典--------------------------------
tree = get_tree(clf)
#-------------------------
#将tree持久化到服务器,服务器中用tree_predict进行预测即可
#-------------------------#------------测试函数的准确性-----------------------------
self_pred_y = np.zeros(len(y))
self_pred_prob = np.zeros((len(y),len(tree['value'][0][0])))
for i in range(X.shape[0]):pred_class,pred_prob = tree_predict(tree,X[i])self_pred_y[i] = pred_classself_pred_prob[i] = pred_prob
pred_y = clf.predict(X)
pred_prob = clf.predict_proba(X)
print("与sklearn预测结果差异个数(类别):",np.sum(pred_y != self_pred_y))
print("与sklearn预测结果差异个数(概率):",np.sum(pred_prob != self_pred_prob))
测试结果:
与sklearn预测结果差异个数(类别): 0
与sklearn预测结果差异个数(概率): 0
相关文章
《深入浅出:决策树入门简介》
《一个简单的决策树分类例子》
《sklearn决策树结果可视化》
《sklearn决策树参数详解》