热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于决策树的性别分类分析

本文旨在探讨如何利用决策树算法实现对男女性别的分类。通过引入信息熵和信息增益的概念,结合具体的数据集,详细介绍了决策树的构建过程,并展示了其在实际应用中的效果。
### 引言

本文是笔者在学习机器学习与人工智能课程时完成的一个作业,任务是使用决策树算法对男女进行分类。该任务虽然简单,但涉及到了决策树的核心原理和实现方法。

### 数据集描述

本研究使用的数据集包含四个特征:头发、声音、脸型和肤质。每个特征有两种可能的取值(如长/短、粗/细、方/圆、粗糙/细腻),共有10个样本,其中男性和女性各5个。以下是数据集的示例图(图1):

![图1 数据集呈现](https://img.php1.cn/3cd4a/1eebe/cd5/99b88427bc9ce0dc.webp)

### 决策树构建原理

决策树是一种监督学习方法,用于分类或回归问题。其核心思想是在每个节点上选择一个最优属性进行分裂,使子节点尽可能纯。这里我们采用信息熵来衡量样本的纯度。信息熵定义为:

\[ Ent(D) = - \sum_{k=1}^{|Y|} p_k \log_2 p_k \]

其中,\(D\)表示样本集,\(|Y|\)表示类别数,\(p_k\)表示第\(k\)类样本的比例。信息熵越小,说明样本的纯度越高。

对于给定的数据集,计算初始信息熵为:

\[ Ent(D) = - (\frac{5}{10}\log_2 \frac{5}{10} + \frac{5}{10}\log_2 \frac{5}{10}) = 1 \]

接下来需要考虑如何选择最优属性进行分裂。为此,我们引入了信息增益的概念,即用样本集的总信息熵减去属性划分后的加权信息熵。信息增益越大,意味着该属性对分类的贡献越大。

例如,以“头发”属性为例,当头发长度为“长”时,有1个女生和3个男生;当头发长度为“短”时,有4个女生和2个男生。因此,“头发”的信息增益计算如下:

\[ Ent(D_{(a=long)}) = - ({1 \over 4}\log_2 {1 \over 4} + {3 \over 4}\log_2 {3 \over 4}) \approx 0.81 \]

\[ Ent(D_{(a=short)}) = - ({4 \over 6}\log_2 {4 \over 6} + {2 \over 6}\log_2 {2 \over 6}) \approx 0.91 \]

\[ gain = Ent(D) - ({4 \over 10}Ent(D_{(a=long)}) + {6 \over 10}Ent(D_{(a=short)})) = 1 - {4 \over 10} \times 0.81 - {6 \over 10} \times 0.91 = 0.13 \]

### 程序实现

以下是使用Python实现决策树的代码片段:

```python
import pandas as pd
import numpy as np

# 计算信息熵
def cal_information_entropy(data):
data_label = data.iloc[:, -1]
label_class = data_label.value_counts()
Ent = 0
for k in label_class.keys():
p_k = label_class[k] / len(data_label)
Ent -= p_k * np.log2(p_k)
return Ent

# 计算信息增益
def cal_information_gain(data, a):
Ent = cal_information_entropy(data)
feature_class = data[a].value_counts()
gain = 0
for v in feature_class.keys():
weight = feature_class[v] / data.shape[0]
Ent_v = cal_information_entropy(data.loc[data[a] == v])
gain += weight * Ent_v
return Ent - gain

# 获取标签最多的那一类
def get_most_label(data):
data_label = data.iloc[:, -1]
label_sort = data_label.value_counts(sort=True)
return label_sort.keys()[0]

# 挑选最优特征
def get_best_feature(data):
features = data.columns[:-1]
res = {}
for a in features:
temp = cal_information_gain(data, a)
res[a] = temp
res = sorted(res.items(), key=lambda x: x[1], reverse=True)
return res[0][0]

# 创建决策树
def create_tree(data):
data_label = data.iloc[:, -1]
if len(data_label.value_counts()) == 1:
return data_label.values[0]
if all(len(data[i].value_counts()) == 1 for i in data.iloc[:, :-1].columns):
return get_most_label(data)
best_feature = get_best_feature(data)
Tree = {best_feature: {}}
exist_vals = pd.unique(data[best_feature])
if len(exist_vals) != len(column_count[best_feature]):
no_exist_attr = set(column_count[best_feature]) - set(exist_vals)
for no_feat in no_exist_attr:
Tree[best_feature][no_feat] = get_most_label(data)
for item in drop_exist_feature(data, best_feature):
Tree[best_feature][item[0]] = create_tree(item[1])
return Tree

# 预测函数
def predict(Tree, test_data):
first_feature = list(Tree.keys())[0]
second_dict = Tree[first_feature]
input_first = test_data.get(first_feature)
input_value = second_dict[input_first]
if isinstance(input_value, dict):
class_label = predict(input_value, test_data)
else:
class_label = input_value
return class_label

# 读取数据
data = pd.read_csv('data_word.csv', encoding='gbk')
column_count = {ds: list(pd.unique(data[ds])) for ds in data.iloc[:, :-1].columns}
dicision_Tree = create_tree(data)
print(dicision_Tree)

# 测试数据
test_data_1 = {'头发': '长', '声音': '粗', '脸型': '方', '肤质': '粗糙'}
test_data_2 = {'头发': '短', '声音': '粗', '脸型': '圆', '肤质': '细腻'}
result = predict(dicision_Tree, test_data_2)
print('分类结果为' + ('男生' if result == 1 else '女生'))
```

### 参考文献

- [CSDN博客](https://blog.csdn.net/IT23131/article/details/121068259)
- [知乎专栏](https://zhuanlan.zhihu.com/p/499238588)
推荐阅读
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • 深入解析:手把手教你构建决策树算法
    本文详细介绍了机器学习中广泛应用的决策树算法,通过天气数据集的实例演示了ID3和CART算法的手动推导过程。文章长度约2000字,建议阅读时间5分钟。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文介绍了如何使用JQuery实现省市二级联动和表单验证。首先,通过change事件监听用户选择的省份,并动态加载对应的城市列表。其次,详细讲解了使用Validation插件进行表单验证的方法,包括内置规则、自定义规则及实时验证功能。 ... [详细]
  • 本文详细介绍了Java中org.eclipse.ui.forms.widgets.ExpandableComposite类的addExpansionListener()方法,并提供了多个实际代码示例,帮助开发者更好地理解和使用该方法。这些示例来源于多个知名开源项目,具有很高的参考价值。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • Kubernetes 持久化存储与数据卷详解
    本文深入探讨 Kubernetes 中持久化存储的使用场景、PV/PVC/StorageClass 的基本操作及其实现原理,旨在帮助读者理解如何高效管理容器化应用的数据持久化需求。 ... [详细]
  • DNN Community 和 Professional 版本的主要差异
    本文详细解析了 DotNetNuke (DNN) 的两种主要版本:Community 和 Professional。通过对比两者的功能和附加组件,帮助用户选择最适合其需求的版本。 ... [详细]
  • UNP 第9章:主机名与地址转换
    本章探讨了用于在主机名和数值地址之间进行转换的函数,如gethostbyname和gethostbyaddr。此外,还介绍了getservbyname和getservbyport函数,用于在服务器名和端口号之间进行转换。 ... [详细]
  • ImmutableX Poised to Pioneer Web3 Gaming Revolution
    ImmutableX is set to spearhead the evolution of Web3 gaming, with its innovative technologies and strategic partnerships driving significant advancements in the industry. ... [详细]
  • 本文详细介绍了macOS系统的核心组件,包括如何管理其安全特性——系统完整性保护(SIP),并探讨了不同版本的更新亮点。对于使用macOS系统的用户来说,了解这些信息有助于更好地管理和优化系统性能。 ... [详细]
  • 本文介绍了如何通过 Maven 依赖引入 SQLiteJDBC 和 HikariCP 包,从而在 Java 应用中高效地连接和操作 SQLite 数据库。文章提供了详细的代码示例,并解释了每个步骤的实现细节。 ... [详细]
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • 社交网络中的级联行为 ... [详细]
author-avatar
怪物-pp_912
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有