【第二章:机器学习与神经网络概述】04.回归算法理论与实践 -(3)决策树回归模型(Decision Tree Regression)
第二章: 机器学习与神经网络概述
第四部分:回归算法理论与实践
第三节:决策树回归模型
内容:剪枝方法、回归树结构与算法实现。
决策树回归模型是一种非参数的监督学习方法,通过将特征空间划分为多个区域,在每个区域内做常数预测,适合处理非线性回归问题、特征交互明显的数据集。
一、基本原理
决策树回归以CART(Classification and Regression Trees)算法为基础,通过不断划分特征空间,构建一棵回归树:
-
每个内部节点表示对某一特征的判断;
-
每个叶节点表示一个预测值(区域内样本均值);
-
划分依据:最小化划分后区域内的均方误差(MSE)。
二、划分准则与误差计算
对样本集 D,假设以特征 的值 s 作为划分点,将样本划分为:
其目标是最小化总的平方误差:
三、剪枝策略(Pruning)
决策树容易过拟合,需通过剪枝来控制复杂度:
1. 预剪枝(Pre-Pruning)
-
在构建过程中提前停止划分:
-
达到最大深度
max_depth
-
每个节点最小样本数
min_samples_split
-
MSE 减少小于阈值
-
2. 后剪枝(Post-Pruning)
-
先生成整棵树,再从底向上剪去“收益小”的分支(如 sklearn 的
ccp_alpha
参数) -
剪枝目标:在保留预测能力的前提下降低模型复杂度
四、Python 实现示例(使用 sklearn)
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_full = DecisionTreeRegressor()
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)# 训练
reg_full.fit(X, y)
reg_pruned.fit(X, y)
reg_ccp.fit(X, y)plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False# 可视化
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, label="data", color="black")
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", linewidth=2)
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (depth=3)", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", linestyle=":")
plt.legend()
plt.title("回归树剪枝效果对比")
plt.xlabel("X")
plt.ylabel("y")
plt.grid(True)
plt.tight_layout()
plt.show()
五、优缺点分析
优点 | 缺点 |
---|---|
逻辑简单、易理解 | 容易过拟合,需要剪枝 |
可处理非线性和多维特征交互 | 对微小变化敏感,稳定性差 |
不需标准化或归一化 | 对样本数量和分布较敏感 |
可解释性强(树结构明确) | 難以推广:小数据表现好,大数据可能需集成优化 |
六、模型调参建议
参数 | 作用 | 建议 |
---|---|---|
max_depth | 限制树的最大深度 | 控制模型复杂度,避免过拟合 |
min_samples_split | 拆分内部节点所需最小样本数 | 增大可减少模型复杂度 |
min_samples_leaf | 每个叶子节点的最小样本数 | 增大有助于平滑预测结果 |
ccp_alpha | 后剪枝惩罚系数(复杂度代价剪枝) | 自动调节树结构,可结合验证集选择最佳值 |
七、典型应用场景
-
房价预测(特征离散明显)
-
电商销售量预测
-
时间序列短期预测(可结合滑窗技术)
-
特征交互复杂但不多的中小数据集建模
补充:
可视化决策树结构
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_pruned = DecisionTreeRegressor(max_depth=3)# 训练
reg_pruned.fit(X, y)plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = Falseplt.figure(figsize=(12, 6))
plot_tree(reg_pruned, filled=True, feature_names=["X"], rounded=True)
plt.title("回归树结构(max_depth=3)")
plt.show()
回归树结构图(plot_tree)
from sklearn.tree import DecisionTreeRegressor, plot_tree
import matplotlib.pyplot as plt
import numpy as np# 构造样本数据
X = np.array([[1], [2], [3], [4], [5], [6], [7], [8]])
y = np.array([5, 4.5, 4, 3.5, 3, 2.5, 2, 1.5])# 创建并训练模型
tree = DecisionTreeRegressor(max_depth=3, random_state=42)
tree.fit(X, y)# 可视化决策树结构
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plot_tree(tree, feature_names=["X"], filled=True, rounded=True)
plt.title("回归树结构图 (max_depth=3)")
plt.show()
剪枝前后预测曲线对比图
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 构造数据
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, X.shape[0])# 不剪枝模型
reg_full = DecisionTreeRegressor()
reg_full.fit(X, y)# 预剪枝模型(限制最大深度)
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_pruned.fit(X, y)# 后剪枝模型(设置复杂度惩罚参数)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)
reg_ccp.fit(X, y)# 测试数据
X_test = np.linspace(0, 5, 500).reshape(-1, 1)# 可视化
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plt.scatter(X, y, label="Train Data", color="black", s=20)
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", color="blue")
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (max_depth=3)", color="green", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", color="red", linestyle=":")
plt.title("回归树剪枝对比图")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()