热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于决策树的性别分类分析

本文旨在探讨如何利用决策树算法实现对男女性别的分类。通过引入信息熵和信息增益的概念,结合具体的数据集,详细介绍了决策树的构建过程,并展示了其在实际应用中的效果。
### 引言

本文是笔者在学习机器学习与人工智能课程时完成的一个作业,任务是使用决策树算法对男女进行分类。该任务虽然简单,但涉及到了决策树的核心原理和实现方法。

### 数据集描述

本研究使用的数据集包含四个特征:头发、声音、脸型和肤质。每个特征有两种可能的取值(如长/短、粗/细、方/圆、粗糙/细腻),共有10个样本,其中男性和女性各5个。以下是数据集的示例图(图1):

![图1 数据集呈现](https://img.php1.cn/3cd4a/1eebe/cd5/99b88427bc9ce0dc.webp)

### 决策树构建原理

决策树是一种监督学习方法,用于分类或回归问题。其核心思想是在每个节点上选择一个最优属性进行分裂,使子节点尽可能纯。这里我们采用信息熵来衡量样本的纯度。信息熵定义为:

\[ Ent(D) = - \sum_{k=1}^{|Y|} p_k \log_2 p_k \]

其中,\(D\)表示样本集,\(|Y|\)表示类别数,\(p_k\)表示第\(k\)类样本的比例。信息熵越小,说明样本的纯度越高。

对于给定的数据集,计算初始信息熵为:

\[ Ent(D) = - (\frac{5}{10}\log_2 \frac{5}{10} + \frac{5}{10}\log_2 \frac{5}{10}) = 1 \]

接下来需要考虑如何选择最优属性进行分裂。为此,我们引入了信息增益的概念,即用样本集的总信息熵减去属性划分后的加权信息熵。信息增益越大,意味着该属性对分类的贡献越大。

例如,以“头发”属性为例,当头发长度为“长”时,有1个女生和3个男生;当头发长度为“短”时,有4个女生和2个男生。因此,“头发”的信息增益计算如下:

\[ Ent(D_{(a=long)}) = - ({1 \over 4}\log_2 {1 \over 4} + {3 \over 4}\log_2 {3 \over 4}) \approx 0.81 \]

\[ Ent(D_{(a=short)}) = - ({4 \over 6}\log_2 {4 \over 6} + {2 \over 6}\log_2 {2 \over 6}) \approx 0.91 \]

\[ gain = Ent(D) - ({4 \over 10}Ent(D_{(a=long)}) + {6 \over 10}Ent(D_{(a=short)})) = 1 - {4 \over 10} \times 0.81 - {6 \over 10} \times 0.91 = 0.13 \]

### 程序实现

以下是使用Python实现决策树的代码片段:

```python
import pandas as pd
import numpy as np

# 计算信息熵
def cal_information_entropy(data):
data_label = data.iloc[:, -1]
label_class = data_label.value_counts()
Ent = 0
for k in label_class.keys():
p_k = label_class[k] / len(data_label)
Ent -= p_k * np.log2(p_k)
return Ent

# 计算信息增益
def cal_information_gain(data, a):
Ent = cal_information_entropy(data)
feature_class = data[a].value_counts()
gain = 0
for v in feature_class.keys():
weight = feature_class[v] / data.shape[0]
Ent_v = cal_information_entropy(data.loc[data[a] == v])
gain += weight * Ent_v
return Ent - gain

# 获取标签最多的那一类
def get_most_label(data):
data_label = data.iloc[:, -1]
label_sort = data_label.value_counts(sort=True)
return label_sort.keys()[0]

# 挑选最优特征
def get_best_feature(data):
features = data.columns[:-1]
res = {}
for a in features:
temp = cal_information_gain(data, a)
res[a] = temp
res = sorted(res.items(), key=lambda x: x[1], reverse=True)
return res[0][0]

# 创建决策树
def create_tree(data):
data_label = data.iloc[:, -1]
if len(data_label.value_counts()) == 1:
return data_label.values[0]
if all(len(data[i].value_counts()) == 1 for i in data.iloc[:, :-1].columns):
return get_most_label(data)
best_feature = get_best_feature(data)
Tree = {best_feature: {}}
exist_vals = pd.unique(data[best_feature])
if len(exist_vals) != len(column_count[best_feature]):
no_exist_attr = set(column_count[best_feature]) - set(exist_vals)
for no_feat in no_exist_attr:
Tree[best_feature][no_feat] = get_most_label(data)
for item in drop_exist_feature(data, best_feature):
Tree[best_feature][item[0]] = create_tree(item[1])
return Tree

# 预测函数
def predict(Tree, test_data):
first_feature = list(Tree.keys())[0]
second_dict = Tree[first_feature]
input_first = test_data.get(first_feature)
input_value = second_dict[input_first]
if isinstance(input_value, dict):
class_label = predict(input_value, test_data)
else:
class_label = input_value
return class_label

# 读取数据
data = pd.read_csv('data_word.csv', encoding='gbk')
column_count = {ds: list(pd.unique(data[ds])) for ds in data.iloc[:, :-1].columns}
dicision_Tree = create_tree(data)
print(dicision_Tree)

# 测试数据
test_data_1 = {'头发': '长', '声音': '粗', '脸型': '方', '肤质': '粗糙'}
test_data_2 = {'头发': '短', '声音': '粗', '脸型': '圆', '肤质': '细腻'}
result = predict(dicision_Tree, test_data_2)
print('分类结果为' + ('男生' if result == 1 else '女生'))
```

### 参考文献

- [CSDN博客](https://blog.csdn.net/IT23131/article/details/121068259)
- [知乎专栏](https://zhuanlan.zhihu.com/p/499238588)
推荐阅读
  • 深入浅出TensorFlow数据读写机制
    本文详细介绍TensorFlow中的数据读写操作,包括TFRecord文件的创建与读取,以及数据集(dataset)的相关概念和使用方法。 ... [详细]
  • 理解与应用:独热编码(One-Hot Encoding)
    本文详细介绍了独热编码(One-Hot Encoding)与哑变量编码(Dummy Encoding)两种方法,用于将分类变量转换为数值形式,以便于机器学习算法处理。文章不仅解释了这两种编码方式的基本原理,还探讨了它们在实际应用中的差异及选择依据。 ... [详细]
  • 交互式左右滑动导航菜单设计
    本文介绍了一种使用HTML和JavaScript实现的左右可点击滑动导航菜单的方法,适用于需要展示多个链接或项目的网页布局。 ... [详细]
  • 任务,栈, ... [详细]
  • Node.js 中 GET 和 POST 请求的数据处理
    本文详细介绍了如何在 Node.js 中使用 GET 和 POST 方法来处理客户端发送的数据。通过示例代码展示了如何解析 URL 参数和表单数据,并提供了完整的实现步骤。 ... [详细]
  • 本文详细介绍了ASP.NET缓存的基本概念和使用方法,包括输出缓存、数据缓存及其高级特性,如缓存依赖、自定义缓存和缓存配置文件等。通过合理利用这些缓存技术,可以显著提升Web应用程序的性能。 ... [详细]
  • 使用WinForms 实现 RabbitMQ RPC 示例
    本文通过两个WinForms应用程序演示了如何使用RabbitMQ实现远程过程调用(RPC)。一个应用作为客户端发送请求,另一个应用作为服务端处理请求并返回响应。 ... [详细]
  • 本文介绍了一种根据目标检测结果,从原始XML文件中提取并分析特定类别的方法。通过解析XML文件,筛选出特定类别的图像和标注信息,并保存到新的文件夹中,以便进一步分析和处理。 ... [详细]
  • 深入解析Hadoop的核心组件与工作原理
    本文详细介绍了Hadoop的三大核心组件:分布式文件系统HDFS、资源管理器YARN和分布式计算框架MapReduce。通过分析这些组件的工作机制,帮助读者更好地理解Hadoop的架构及其在大数据处理中的应用。 ... [详细]
  • 本文探讨了如何使用pg-promise库在PostgreSQL中高效地批量插入多条记录,包括通过事务和单一查询两种方法。 ... [详细]
  • 掌握Mosek矩阵运算,轻松应对优化挑战
    本篇文章继续深入探讨Mosek学习笔记系列,特别是矩阵运算部分,这对于优化问题的解决至关重要。通过本文,您将了解到如何高效地使用Mosek进行矩阵初始化、线性代数运算及约束域的设定。 ... [详细]
  • Spring Cloud Config 使用 Vault 作为配置存储
    本文探讨了如何在Spring Cloud Config中集成HashiCorp Vault作为配置存储解决方案,基于Spring Cloud Hoxton.RELEASE及Spring Boot 2.2.1.RELEASE版本。文章还提供了详细的配置示例和实践建议。 ... [详细]
  • LambdaMART算法详解
    本文详细介绍了LambdaMART算法的背景、原理及其在信息检索中的应用。首先回顾了LambdaMART的发展历程,包括其前身RankNet和LambdaRank,然后深入探讨了LambdaMART如何结合梯度提升决策树(GBDT)和LambdaRank来优化排序问题。 ... [详细]
  • 利用YAML配置Resilience4J的Circuit Breaker
    本文探讨了Resilience4j作为现代Java应用程序中不可或缺的容错工具,特别介绍了如何通过YAML文件配置Circuit Breaker以提高服务的弹性和稳定性。 ... [详细]
  • 本文档提供了如何使用C#代码从客户订单中提取产品信息的方法,适用于需要处理和分析产品数据的应用场景。 ... [详细]
author-avatar
怪物-pp_912
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有