【机器学习笔记 Ⅱ】11 决策树模型
决策树模型(Decision Tree)详解
决策树是一种树形结构的监督学习模型,通过一系列规则对数据进行分类或回归。其核心思想是模仿人类决策过程,通过不断提问(基于特征划分)逐步逼近答案。
1. 核心概念
- 节点类型:
- 根节点:起始问题(最佳特征划分点)。
- 内部节点:中间决策步骤(特征判断)。
- 叶节点:最终预测结果(类别或数值)。
- 分支:对应特征的取值或条件判断(如“年龄≥30?”)。
2. 构建决策树的关键步骤
(1) 特征选择
选择最优特征进行划分,常用准则:
-
分类任务:
-
回归任务:
- 均方误差(MSE)最小化:选择使子节点方差下降最多的特征。
(2) 划分停止条件
- 当前节点样本属于同一类别。
- 样本数少于预设阈值(如
min_samples_split=5
)。 - 树的深度达到最大值(
max_depth
)。
(3) 剪枝(防止过拟合)
- 预剪枝:在划分前评估,若划分不能提升性能则停止。
- 后剪枝:先生成完整树,再自底向上剪枝(如CCP方法)。
3. 决策树示例
问题:预测是否批准贷款。
特征:年龄、收入、信用评分。
树结构:
- 根节点:信用评分 ≥ 650?
- 是 → 叶节点:批准。
- 否 → 内部节点:收入 ≥ 50k?
- 是 → 叶节点:批准。
- 否 → 叶节点:拒绝。
4. 代码实现
(1) Scikit-learn分类树
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris# 加载数据
data = load_iris()
X, y = data.data, data.target# 训练模型
clf = DecisionTreeClassifier(criterion='gini', max_depth=3)
clf.fit(X, y)# 可视化树
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=data.feature_names, class_names=data.target_names)
plt.show()
(2) Scikit-learn回归树
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import fetch_california_housingdata = fetch_california_housing()
X, y = data.data, data.targetreg = DecisionTreeRegressor(max_depth=2)
reg.fit(X, y)
5. 优缺点对比
优点 | 缺点 |
---|---|
1. 可解释性强:规则直观易懂。 | 1. 容易过拟合:需剪枝或限制深度。 |
2. 无需特征缩放:对数据分布不敏感。 | 2. 不稳定:数据微小变化可能导致树结构剧变。 |
3. 处理混合类型数据:数值和类别特征均可。 | 3. 偏向多值特征:信息增益可能偏好取值多的特征。 |
6. 进阶应用
(1) 集成方法
- 随机森林(Random Forest):多棵决策树投票,降低方差。
- 梯度提升树(GBDT/XGBoost):逐步修正前序树的误差。
(2) 多输出任务
- 支持同时预测多个目标(如分类+回归)。
(3) 解释工具
- SHAP值:量化特征对单样本预测的影响。
import shap explainer = shap.TreeExplainer(clf) shap_values = explainer.shap_values(X) shap.summary_plot(shap_values, X, feature_names=data.feature_names)
7. 关键参数调优
参数 | 作用 | 常用值 |
---|---|---|
max_depth | 控制树的最大深度 | 3-10(防过拟合) |
min_samples_split | 节点分裂所需最小样本数 | 2-5 |
min_samples_leaf | 叶节点最少样本数 | 1-5 |
criterion | 分裂标准(基尼/熵/均方误差) | gini (分类) |
8. 总结
- 决策树本质:通过递归划分特征空间实现预测。
- 适用场景:
- 需要可解释性的业务(如金融风控、医疗诊断)。
- 小规模数据集或特征含义明确的任务。
- 升级方向:集成学习(如随机森林、XGBoost)提升性能。