当前位置: 首页 > news >正文

决策树原理与 Sklearn 实战

目录

引言:为什么要学决策树?

一、决策树核心原理:如何 “生长” 一棵决策树?

1.1 信息熵:衡量不确定性的 “尺子”

1.1.1 信息熵的定义

1.1.2 信息熵的直观理解

1.2 特征选择:用 “信息增益” 找最优分裂特征

1.2.1 信息增益的定义

1.2.2 实例:用贷款数据计算信息增益

1.3 决策树的分裂停止条件

二、主流决策树算法对比:ID3、C4.5、CART

关键补充:基尼系数是什么?

三、Sklearn 实战:决策树分类与可视化

3.1 环境准备

3.2 完整代码流程

步骤 1:导入库与加载数据集

步骤 2:划分训练集与测试集

步骤 3:实例化并训练决策树模型

步骤 4:模型预测与评估

步骤 5:决策树可视化

步骤 6:特征重要性分析

四、决策树的优缺点与优化策略

4.1 决策树的优缺点

优点:

缺点:

4.2 优化策略

1. 剪枝(Pruning):防止过拟合的核心手段

2. 集成学习:提升稳定性与泛化能力

3. 特征工程:减少噪声与冗余

五、总结


引言:为什么要学决策树?

在机器学习领域,决策树是最经典、最易理解的算法之一。它的核心思想源于人类的决策过程 —— 比如我们判断 “是否出门郊游” 时,会依次考虑 “是否下雨”“温度是否适宜”“是否有时间” 等条件,最终得出结论。这种 “if-then” 的分支逻辑,让决策树模型天生具备可解释性强、可视化友好的优势,无需复杂的数学推导就能理解模型决策过程。

决策树的应用场景非常广泛:

  • 金融风控:根据用户收入、征信记录、负债情况判断贷款违约风险;
  • 医疗诊断:依据患者症状、检查指标判断疾病类型;
  • 电商推荐:基于用户购买历史、浏览行为分类用户偏好;
  • 工业质检:通过产品尺寸、重量、材质等特征判断是否合格。

本文将从核心原理→算法对比→Sklearn 实战→优化策略四个维度,带你彻底掌握决策树,零基础也能轻松上手。

一、决策树核心原理:如何 “生长” 一棵决策树?

决策树的构建过程本质是 **“特征选择→节点分裂→树修剪”** 的循环,核心是解决两个问题:“用哪个特征分裂节点?”“分裂到什么时候停止?”。要回答这两个问题,需要先理解 “信息熵”“信息增益” 等关键概念。

1.1 信息熵:衡量不确定性的 “尺子”

决策树的本质是降低不确定性—— 从根节点(所有样本混合)到叶节点(样本类别单一),每一次分裂都要让 “类别混乱程度” 下降。而 “信息熵(Information Entropy)” 就是量化这种 “混乱程度” 的指标。

1.1.1 信息熵的定义

香农在 1948 年提出信息熵,单位为 “比特(bit)”,公式如下: \(H(X) = -\sum_{x \in X} P(x) \log_2 P(x)\) 其中:

  • X 是样本集合的类别空间(比如 “贷款批准”“贷款拒绝”);
  • \(P(x)\) 是某类别 x 在样本集中的概率(比如 100 个样本中 60 个批准,\(P(批准)=0.6\));
  • 负号是为了保证结果非负(因为 \(\log_2 P(x)\) 对 \(P(x) \in (0,1)\) 是负数)。
1.1.2 信息熵的直观理解

信息熵越高,样本类别越混乱;信息熵越低,类别越集中。 举个例子:

  • 若 100 个贷款样本中,100 个都批准(\(P(批准)=1\),\(P(拒绝)=0\)),则 \(H(X) = -1 \times \log_2 1 - 0 \times \log_2 0 = 0\)(完全确定,无混乱);
  • 若 50 个批准、50 个拒绝(\(P(批准)=0.5\),\(P(拒绝)=0.5\)),则 \(H(X) = -0.5\log_2 0.5 -0.5\log_2 0.5 = 1\)(最混乱,熵最大)。

1.2 特征选择:用 “信息增益” 找最优分裂特征

有了信息熵,我们就可以通过 “信息增益” 判断哪个特征对降低不确定性最有效 ——信息增益越大,该特征的分类能力越强

1.2.1 信息增益的定义

信息增益 \(g(D,A)\) 是 “分裂前的信息熵 \(H(D)\)” 与 “分裂后各子节点信息熵的加权平均 \(H(D|A)\)” 的差值: \(g(D,A) = H(D) - H(D|A)\) 其中:

  • D 是当前样本集合;
  • A 是待选择的分裂特征;
  • \(H(D|A)\) 是 “特征 A 条件下的条件熵”,计算方式为:先按 A 的取值划分样本为多个子集 \(D_1,D_2,...,D_k\),再计算每个子集的熵,最后按子集大小加权求和。
1.2.2 实例:用贷款数据计算信息增益

假设我们有 10 个贷款申请样本,特征包括 “年龄(青年 / 中年 / 老年)”“是否有工作(是 / 否)”,目标是 “是否批准贷款”,样本分布如下:

样本 ID年龄是否有工作贷款结果
1青年拒绝
2青年拒绝
3青年批准
4中年拒绝
5中年批准
6中年批准
7老年批准
8老年批准
9老年批准
10老年批准

步骤 1:计算根节点的信息熵 \(H(D)\) 贷款结果分布:批准 7 个,拒绝 3 个。 \(H(D) = -0.7\log_2 0.7 - 0.3\log_2 0.3 \approx 0.881\)

步骤 2:计算 “年龄” 特征的条件熵 \(H(D|年龄)\)

  • 青年(3 个样本):拒绝 2,批准 1 → \(H(青年) = -2/3\log_2(2/3) -1/3\log_2(1/3) \approx 0.918\)
  • 中年(3 个样本):拒绝 1,批准 2 → \(H(中年) = -1/3\log_2(1/3) -2/3\log_2(2/3) \approx 0.918\)
  • 老年(4 个样本):拒绝 0,批准 4 → \(H(老年) = 0\)

条件熵加权求和(子集大小占比:3/10、3/10、4/10): \(H(D|年龄) = (3/10)\times0.918 + (3/10)\times0.918 + (4/10)\times0 = 0.551\)

步骤 3:计算 “年龄” 的信息增益 \(g(D,年龄) = H(D) - H(D|年龄) = 0.881 - 0.551 = 0.330\)

同理,可计算 “是否有工作” 的信息增益(最终约为 0.420)。由于 “是否有工作” 的信息增益更大,因此优先用该特征分裂根节点。

1.3 决策树的分裂停止条件

为了避免树过深导致过拟合,当满足以下任一条件时,停止分裂:

  1. 当前节点所有样本属于同一类别(熵为 0);
  2. 没有剩余特征可用于分裂;
  3. 当前节点样本数量小于预设阈值(如min_samples_split=2);
  4. 树的深度达到预设最大值(如max_depth=5)。

二、主流决策树算法对比:ID3、C4.5、CART

决策树的核心差异在于特征选择的准则,目前主流的三种算法分别是 ID3、C4.5 和 CART,它们的对比如下表:

算法特征选择准则支持特征类型支持任务优缺点
ID3信息增益最大离散特征分类优点:简单直观;缺点:偏向多值特征(如 “身份证号”)、不支持连续特征
C4.5信息增益比最大离散 + 连续分类优点:修正多值特征偏向、支持连续特征(离散化处理)、支持剪枝;缺点:计算复杂
CART基尼系数最小(分类) 平方误差最小(回归)离散 + 连续分类 + 回归优点:计算效率高(无需对数)、支持回归任务、生成二叉树;缺点:易过拟合
关键补充:基尼系数是什么?

CART 算法用 “基尼系数(Gini Index)” 衡量混乱程度,公式如下: \(Gini(D) = 1 - \sum_{x \in X} P(x)^2\) 基尼系数的含义是 “随机抽取两个样本,类别不同的概率”,取值范围 [0,1]:

  • Gini=0:所有样本类别相同(完全确定);
  • Gini=0.5:两类样本各占 50%(最混乱)。

与信息熵相比,基尼系数计算更简单(无需对数运算),在大规模数据上效率更高,因此工业界更常用。

三、Sklearn 实战:决策树分类与可视化

Sklearn 是 Python 中最常用的机器学习库,其tree模块提供了完整的决策树实现。本节将以鸢尾花分类任务为例,带你完成从模型训练到可视化的全流程。

3.1 环境准备

首先安装必要的库(若未安装):

pip install scikit-learn pandas numpy matplotlib graphviz pydotplus
  • graphvizpydotplus用于决策树可视化;
  • 若安装graphviz后报错,需在系统环境变量中添加graphvizbin目录(如 Windows:C:\Program Files\Graphviz\bin)。

3.2 完整代码流程

步骤 1:导入库与加载数据集

鸢尾花数据集包含 150 个样本,3 个类别(山鸢尾、变色鸢尾、维吉尼亚鸢尾),4 个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度):

# 导入库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.tree import export_graphviz
import pydotplus
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd# 加载数据集
iris = load_iris()
X = iris.data  # 特征:(150,4)
y = iris.target  # 标签:(150,)
feature_names = iris.feature_names  # 特征名:["sepal length (cm)", ...]
target_names = iris.target_names  # 类别名:["setosa", "versicolor", "virginica"]
步骤 2:划分训练集与测试集

train_test_split按 7:3 划分训练集和测试集,random_state=42保证结果可复现:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42  # 测试集占30%
)
print(f"训练集大小:{X_train.shape[0]},测试集大小:{X_test.shape[0]}")
步骤 3:实例化并训练决策树模型

关键参数说明:

  • criterion:特征选择准则,可选"gini"(默认)或"entropy"
  • max_depth:树的最大深度,防止过拟合(建议从 3 开始调试);
  • random_state:随机种子,保证每次训练结果一致。
# 实例化模型
dt_clf = DecisionTreeClassifier(criterion="gini",  # 用基尼系数max_depth=3,       # 最大深度3random_state=42
)# 训练模型
dt_clf.fit(X_train, y_train)
步骤 4:模型预测与评估

用测试集评估模型性能,常用指标包括准确率(Accuracy)混淆矩阵(Confusion Matrix)精确率(Precision)召回率(Recall)

# 预测测试集
y_pred = dt_clf.predict(X_test)# 1. 准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率:{accuracy:.2f}")  # 输出:模型准确率:1.00(因鸢尾花数据简单)# 2. 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
cm_df = pd.DataFrame(cm, index=target_names, columns=target_names)
plt.figure(figsize=(8,6))
sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
plt.title("决策树混淆矩阵")
plt.xlabel("预测类别")
plt.ylabel("真实类别")
plt.show()# 3. 分类报告(精确率、召回率、F1-score)
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=target_names))
步骤 5:决策树可视化

通过export_graphviz导出 DOT 格式文件,再用pydotplus转换为 PDF 或 PNG:

# 方法1:生成PDF文件
dot_data = export_graphviz(dt_clf,out_file=None,  # 不保存为文件,直接返回字符串feature_names=feature_names,  # 特征名class_names=target_names,     # 类别名filled=True,                  # 节点填充颜色(颜色越深,纯度越高)rounded=True                  # 节点圆角
)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris_decision_tree.pdf")  # 保存为PDF
print("决策树已保存为 iris_decision_tree.pdf")# 方法2:在Jupyter Notebook中直接显示(可选)
# from IPython.display import Image
# Image(graph.create_png())
步骤 6:特征重要性分析

决策树会计算每个特征对分类的贡献度(feature_importances_),可用于特征选择:

# 计算特征重要性
feature_importance = pd.DataFrame({"特征名": feature_names,"重要性": dt_clf.feature_importances_
}).sort_values(by="重要性", ascending=False)print("\n特征重要性:")
print(feature_importance)# 可视化特征重要性
plt.figure(figsize=(8,6))
sns.barplot(x="重要性", y="特征名", data=feature_importance)
plt.title("决策树特征重要性")
plt.show()

运行结果会显示:花瓣长度(petal length) 和花瓣宽度(petal width) 的重要性最高,这符合鸢尾花分类的常识(花瓣特征对类别区分更关键)。

四、决策树的优缺点与优化策略

4.1 决策树的优缺点

优点:
  1. 可解释性极强:可视化后可直接看到 “决策规则”(如 “花瓣长度≤2.45cm → 山鸢尾”);
  2. 无需数据预处理:不需要归一化、标准化,对缺失值不敏感(Sklearn 实现需处理缺失值);
  3. 训练速度快:基于贪心策略,每一步选择最优特征,时间复杂度较低;
  4. 支持多分类与回归:CART 算法可同时处理分类和回归任务。
缺点:
  1. 易过拟合:树过深时会 memorize 训练数据的噪声,泛化能力差;
  2. 不稳定:训练数据微小变化可能导致树结构巨变;
  3. 偏向多值特征:如 “用户 ID” 这类特征,每个值对应少量样本,信息增益可能虚高;
  4. 不擅长处理线性关系:对于 “X1+X2>5” 这类线性决策边界,决策树需要多层分裂才能拟合。

4.2 优化策略

1. 剪枝(Pruning):防止过拟合的核心手段
  • 预剪枝(Pre-pruning):在树生长过程中提前停止分裂,常用参数:
    • max_depth:限制树的最大深度;
    • min_samples_split:节点分裂所需的最小样本数(如≥2);
    • min_samples_leaf:叶节点所需的最小样本数(如≥5)。
  • 后剪枝(Post-pruning):先生成完整的树,再删除冗余分支(Sklearn 暂不支持,可使用tree.DecisionTreeClassifierccp_alpha参数实现成本复杂度剪枝)。
2. 集成学习:提升稳定性与泛化能力

将多个弱决策树组合成强模型,解决单一决策树不稳定的问题:

  • 随机森林(Random Forest):多棵决策树并行训练,通过投票输出结果(Sklearnensemble.RandomForestClassifier);
  • XGBoost/LightGBM:梯度提升树,串行训练多棵树,每棵树修正前一棵树的误差,精度更高(工业界常用)。
3. 特征工程:减少噪声与冗余
  • 去除低重要性特征(如通过feature_importances_筛选);
  • 对连续特征离散化(C4.5 已自动处理,但手动调整分箱可提升效果);
  • 避免使用 “用户 ID”“订单号” 等多值无意义特征。

五、总结

决策树是机器学习的 “入门基石”,它的核心是通过信息熵 / 基尼系数选择最优特征,逐步降低类别不确定性。本文从理论(信息熵、信息增益)到实践(Sklearn 训练、可视化),再到优化(剪枝、集成学习),完整覆盖了决策树的关键知识点。

对于初学者,建议先掌握:

  1. 信息熵与基尼系数的物理含义;
  2. SklearnDecisionTreeClassifier的核心参数(criterionmax_depth);
  3. 决策树可视化与特征重要性分析。
http://www.dtcms.com/a/350495.html

相关文章:

  • 【动手学深度学习】7.1. 深度卷积神经网络(AlexNet)
  • 0825 http梳理作业
  • 【慕伏白】CTFHub 技能树学习笔记 -- Web 之信息泄露
  • Linux多线程[生产者消费者模型]
  • python项目中pyproject.toml是做什么用的
  • 【Canvas与标牌】维兰德汤谷公司logo
  • Hadoop MapReduce Task 设计源码分析
  • java-代码随想录第十七天| 700.二叉搜索树中的搜索、617.合并二叉树、98.验证二叉搜索树
  • C++ STL 专家容器:关联式、哈希与适配器
  • 《微服务架构下API网关流量控制Bug复盘:从熔断失效到全链路防护》
  • 精准测试的密码:解密等价类划分,让Bug无处可逃
  • 【C语言16天强化训练】从基础入门到进阶:Day 11
  • 朴素贝叶斯算法总结
  • 互联网大厂Java面试实录:Spring Boot与微服务架构解析
  • cmd命令行删除文件夹
  • rk3566编译squashfs报错解决
  • QT5封装的日志记录函数
  • 算法练习-遍历对角线
  • 开源夜莺里如何引用标签和注解变量
  • VTK开发笔记(四):示例Cone,创建圆锥体,在Qt窗口中详解复现对应的Demo
  • 使用Cloudflare的AI Gateway代理Google AI Studio
  • 论文阅读:Code as Policies: Language Model Programs for Embodied Control
  • Redis的单线程和多线程
  • Linux_用 `ps` 按进程名过滤线程,以及用 `pkill` 按进程名安全杀进程
  • 记一次RocketMQ消息堆积
  • (二十二)深入了解AVFoundation-编辑:视频变速功能-实战在Demo中实现视频变速
  • 数字人视频创作革命!开源免费无时限InfiniteTalk ,数字人图片 + 音频一键生成无限长视频
  • ADC-工业信号采集卡-K004规格书
  • 智能电视MaxHub恢复系统
  • 【第十章】Python 文件操作深度解析:从底层逻辑到多场景实战​