热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于决策树的性别分类分析

本文旨在探讨如何利用决策树算法实现对男女性别的分类。通过引入信息熵和信息增益的概念,结合具体的数据集,详细介绍了决策树的构建过程,并展示了其在实际应用中的效果。
### 引言

本文是笔者在学习机器学习与人工智能课程时完成的一个作业,任务是使用决策树算法对男女进行分类。该任务虽然简单,但涉及到了决策树的核心原理和实现方法。

### 数据集描述

本研究使用的数据集包含四个特征:头发、声音、脸型和肤质。每个特征有两种可能的取值(如长/短、粗/细、方/圆、粗糙/细腻),共有10个样本,其中男性和女性各5个。以下是数据集的示例图(图1):

![图1 数据集呈现](https://img.php1.cn/3cd4a/1eebe/cd5/99b88427bc9ce0dc.webp)

### 决策树构建原理

决策树是一种监督学习方法,用于分类或回归问题。其核心思想是在每个节点上选择一个最优属性进行分裂,使子节点尽可能纯。这里我们采用信息熵来衡量样本的纯度。信息熵定义为:

\[ Ent(D) = - \sum_{k=1}^{|Y|} p_k \log_2 p_k \]

其中,\(D\)表示样本集,\(|Y|\)表示类别数,\(p_k\)表示第\(k\)类样本的比例。信息熵越小,说明样本的纯度越高。

对于给定的数据集,计算初始信息熵为:

\[ Ent(D) = - (\frac{5}{10}\log_2 \frac{5}{10} + \frac{5}{10}\log_2 \frac{5}{10}) = 1 \]

接下来需要考虑如何选择最优属性进行分裂。为此,我们引入了信息增益的概念,即用样本集的总信息熵减去属性划分后的加权信息熵。信息增益越大,意味着该属性对分类的贡献越大。

例如,以“头发”属性为例,当头发长度为“长”时,有1个女生和3个男生;当头发长度为“短”时,有4个女生和2个男生。因此,“头发”的信息增益计算如下:

\[ Ent(D_{(a=long)}) = - ({1 \over 4}\log_2 {1 \over 4} + {3 \over 4}\log_2 {3 \over 4}) \approx 0.81 \]

\[ Ent(D_{(a=short)}) = - ({4 \over 6}\log_2 {4 \over 6} + {2 \over 6}\log_2 {2 \over 6}) \approx 0.91 \]

\[ gain = Ent(D) - ({4 \over 10}Ent(D_{(a=long)}) + {6 \over 10}Ent(D_{(a=short)})) = 1 - {4 \over 10} \times 0.81 - {6 \over 10} \times 0.91 = 0.13 \]

### 程序实现

以下是使用Python实现决策树的代码片段:

```python
import pandas as pd
import numpy as np

# 计算信息熵
def cal_information_entropy(data):
data_label = data.iloc[:, -1]
label_class = data_label.value_counts()
Ent = 0
for k in label_class.keys():
p_k = label_class[k] / len(data_label)
Ent -= p_k * np.log2(p_k)
return Ent

# 计算信息增益
def cal_information_gain(data, a):
Ent = cal_information_entropy(data)
feature_class = data[a].value_counts()
gain = 0
for v in feature_class.keys():
weight = feature_class[v] / data.shape[0]
Ent_v = cal_information_entropy(data.loc[data[a] == v])
gain += weight * Ent_v
return Ent - gain

# 获取标签最多的那一类
def get_most_label(data):
data_label = data.iloc[:, -1]
label_sort = data_label.value_counts(sort=True)
return label_sort.keys()[0]

# 挑选最优特征
def get_best_feature(data):
features = data.columns[:-1]
res = {}
for a in features:
temp = cal_information_gain(data, a)
res[a] = temp
res = sorted(res.items(), key=lambda x: x[1], reverse=True)
return res[0][0]

# 创建决策树
def create_tree(data):
data_label = data.iloc[:, -1]
if len(data_label.value_counts()) == 1:
return data_label.values[0]
if all(len(data[i].value_counts()) == 1 for i in data.iloc[:, :-1].columns):
return get_most_label(data)
best_feature = get_best_feature(data)
Tree = {best_feature: {}}
exist_vals = pd.unique(data[best_feature])
if len(exist_vals) != len(column_count[best_feature]):
no_exist_attr = set(column_count[best_feature]) - set(exist_vals)
for no_feat in no_exist_attr:
Tree[best_feature][no_feat] = get_most_label(data)
for item in drop_exist_feature(data, best_feature):
Tree[best_feature][item[0]] = create_tree(item[1])
return Tree

# 预测函数
def predict(Tree, test_data):
first_feature = list(Tree.keys())[0]
second_dict = Tree[first_feature]
input_first = test_data.get(first_feature)
input_value = second_dict[input_first]
if isinstance(input_value, dict):
class_label = predict(input_value, test_data)
else:
class_label = input_value
return class_label

# 读取数据
data = pd.read_csv('data_word.csv', encoding='gbk')
column_count = {ds: list(pd.unique(data[ds])) for ds in data.iloc[:, :-1].columns}
dicision_Tree = create_tree(data)
print(dicision_Tree)

# 测试数据
test_data_1 = {'头发': '长', '声音': '粗', '脸型': '方', '肤质': '粗糙'}
test_data_2 = {'头发': '短', '声音': '粗', '脸型': '圆', '肤质': '细腻'}
result = predict(dicision_Tree, test_data_2)
print('分类结果为' + ('男生' if result == 1 else '女生'))
```

### 参考文献

- [CSDN博客](https://blog.csdn.net/IT23131/article/details/121068259)
- [知乎专栏](https://zhuanlan.zhihu.com/p/499238588)
推荐阅读
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • 扫描线三巨头 hdu1928hdu 1255  hdu 1542 [POJ 1151]
    学习链接:http:blog.csdn.netlwt36articledetails48908031学习扫描线主要学习的是一种扫描的思想,后期可以求解很 ... [详细]
  • 机器学习中的相似度度量与模型优化
    本文探讨了机器学习中常见的相似度度量方法,包括余弦相似度、欧氏距离和马氏距离,并详细介绍了如何通过选择合适的模型复杂度和正则化来提高模型的泛化能力。此外,文章还涵盖了模型评估的各种方法和指标,以及不同分类器的工作原理和应用场景。 ... [详细]
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本题探讨如何通过最大流算法解决农场排水系统的设计问题。题目要求计算从水源点到汇合点的最大水流速率,使用经典的EK(Edmonds-Karp)和Dinic算法进行求解。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • PHP 5.5.0rc1 发布:深入解析 Zend OPcache
    2013年5月9日,PHP官方发布了PHP 5.5.0rc1和PHP 5.4.15正式版,这两个版本均支持64位环境。本文将详细介绍Zend OPcache的功能及其在Windows环境下的配置与测试。 ... [详细]
  • 本题通过将每个矩形视为一个节点,根据其相对位置构建拓扑图,并利用深度优先搜索(DFS)或状态压缩动态规划(DP)求解最小涂色次数。本文详细解析了该问题的建模思路与算法实现。 ... [详细]
  • 本文探讨了《魔兽世界》中红蓝两方阵营在备战阶段的策略与实现方法,通过代码展示了双方如何根据资源和兵种特性进行战士生产。 ... [详细]
author-avatar
怪物-pp_912
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有