当前位置: 首页 > news >正文

【第二章:机器学习与神经网络概述】04.回归算法理论与实践 -(3)决策树回归模型(Decision Tree Regression)

第二章: 机器学习与神经网络概述

第四部分:回归算法理论与实践

第三节:决策树回归模型

内容:剪枝方法、回归树结构与算法实现。

决策树回归模型是一种非参数的监督学习方法,通过将特征空间划分为多个区域,在每个区域内做常数预测,适合处理非线性回归问题、特征交互明显的数据集。


一、基本原理

决策树回归以CART(Classification and Regression Trees)算法为基础,通过不断划分特征空间,构建一棵回归树:

  • 每个内部节点表示对某一特征的判断;

  • 每个叶节点表示一个预测值(区域内样本均值);

  • 划分依据:最小化划分后区域内的均方误差(MSE)。


二、划分准则与误差计算

对样本集 D,假设以特征 x_j 的值 s 作为划分点,将样本划分为:

  • D_{\text{left}} = \{x \in D \mid x_j \leq s\}

  • D_{\text{right}} = \{x \in D \mid x_j > s\}

其目标是最小化总的平方误差:

\min_{j,s} \left[ \sum_{x_i \in D_{\text{left}}} (y_i - \bar{y}_{\text{left}})^2 + \sum_{x_i \in D_{\text{right}}} (y_i - \bar{y}_{\text{right}})^2 \right]


三、剪枝策略(Pruning)

决策树容易过拟合,需通过剪枝来控制复杂度:

1. 预剪枝(Pre-Pruning)
  • 在构建过程中提前停止划分:

    • 达到最大深度 max_depth

    • 每个节点最小样本数 min_samples_split

    • MSE 减少小于阈值

2. 后剪枝(Post-Pruning)
  • 先生成整棵树,再从底向上剪去“收益小”的分支(如 sklearn 的 ccp_alpha 参数)

  • 剪枝目标:在保留预测能力的前提下降低模型复杂度


四、Python 实现示例(使用 sklearn)

from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_full = DecisionTreeRegressor()
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)# 训练
reg_full.fit(X, y)
reg_pruned.fit(X, y)
reg_ccp.fit(X, y)plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False# 可视化
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, label="data", color="black")
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", linewidth=2)
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (depth=3)", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", linestyle=":")
plt.legend()
plt.title("回归树剪枝效果对比")
plt.xlabel("X")
plt.ylabel("y")
plt.grid(True)
plt.tight_layout()
plt.show()

 


五、优缺点分析

优点缺点
逻辑简单、易理解容易过拟合,需要剪枝
可处理非线性和多维特征交互对微小变化敏感,稳定性差
不需标准化或归一化对样本数量和分布较敏感
可解释性强(树结构明确)難以推广:小数据表现好,大数据可能需集成优化

六、模型调参建议

参数作用建议
max_depth限制树的最大深度控制模型复杂度,避免过拟合
min_samples_split拆分内部节点所需最小样本数增大可减少模型复杂度
min_samples_leaf每个叶子节点的最小样本数增大有助于平滑预测结果
ccp_alpha后剪枝惩罚系数(复杂度代价剪枝)自动调节树结构,可结合验证集选择最佳值

七、典型应用场景

  • 房价预测(特征离散明显)

  • 电商销售量预测

  • 时间序列短期预测(可结合滑窗技术)

  • 特征交互复杂但不多的中小数据集建模


补充:

可视化决策树结构
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 构造数据
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])# 模型:不剪枝 vs 预剪枝 vs 后剪枝
reg_pruned = DecisionTreeRegressor(max_depth=3)# 训练
reg_pruned.fit(X, y)plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = Falseplt.figure(figsize=(12, 6))
plot_tree(reg_pruned, filled=True, feature_names=["X"], rounded=True)
plt.title("回归树结构(max_depth=3)")
plt.show()

 


回归树结构图(plot_tree)
from sklearn.tree import DecisionTreeRegressor, plot_tree
import matplotlib.pyplot as plt
import numpy as np# 构造样本数据
X = np.array([[1], [2], [3], [4], [5], [6], [7], [8]])
y = np.array([5, 4.5, 4, 3.5, 3, 2.5, 2, 1.5])# 创建并训练模型
tree = DecisionTreeRegressor(max_depth=3, random_state=42)
tree.fit(X, y)# 可视化决策树结构
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plot_tree(tree, feature_names=["X"], filled=True, rounded=True)
plt.title("回归树结构图 (max_depth=3)")
plt.show()

 


剪枝前后预测曲线对比图
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 构造数据
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, X.shape[0])# 不剪枝模型
reg_full = DecisionTreeRegressor()
reg_full.fit(X, y)# 预剪枝模型(限制最大深度)
reg_pruned = DecisionTreeRegressor(max_depth=3)
reg_pruned.fit(X, y)# 后剪枝模型(设置复杂度惩罚参数)
reg_ccp = DecisionTreeRegressor(ccp_alpha=0.01)
reg_ccp.fit(X, y)# 测试数据
X_test = np.linspace(0, 5, 500).reshape(-1, 1)# 可视化
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 6))
plt.scatter(X, y, label="Train Data", color="black", s=20)
plt.plot(X_test, reg_full.predict(X_test), label="Full Tree", color="blue")
plt.plot(X_test, reg_pruned.predict(X_test), label="Pre-Pruned (max_depth=3)", color="green", linestyle="--")
plt.plot(X_test, reg_ccp.predict(X_test), label="Post-Pruned (ccp_alpha=0.01)", color="red", linestyle=":")
plt.title("回归树剪枝对比图")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

 

http://www.dtcms.com/a/263808.html

相关文章:

  • UE5.6 官方文档笔记 [1]——虚幻编辑器界面
  • Python 单例模式与魔法方法:深度解析与实践应用
  • MySQL允许root用户远程连接
  • PDFBox + Tess4J 从PDF中提取图片OCR识别文字
  • 探秘阿里云Alibaba Cloud Linux:云时代的操作系统新宠
  • C语言学习笔记:深入解析结构体数组(附代码实践)
  • Qt QTableWidget多行多列复制粘贴
  • Android 网络全栈攻略(四)—— TCPIP 协议族与 HTTPS 协议
  • 安全左移(Shift Left Security):软件安全的演进之路
  • Spring Boot 2 多模块项目中配置文件的加载顺序
  • 智能交通信号灯
  • Django打造智能Web机器人控制平台
  • HarmonyOS应用开发高级认证知识点梳理 (三)状态管理V2装饰器核心规则
  • android车载开发之HVAC
  • 笔记本电脑怎样投屏到客厅的大电视?怎样避免将电脑全部画面都投出去?
  • 【蓝牙】Linux Qt4查看已经配对的蓝牙信息
  • 05【C++ 入门基础】内联、auto、指针空值
  • 算法-每日一题(DAY12)最长和谐子序列
  • 为Mkdocs网站添加Google广告
  • CRMEB开源商城系统Windows+IIS环境安装配置详解
  • word中一行未满但是后面有空白行
  • 每日一练:找到初始输入字符串 I
  • AbMole| H₂DCFDA(M9096;活性氧(ROS)探针)
  • MySQL索引深度解析:B+树、B树、哈希索引怎么选?
  • 凸包进阶旋转卡壳(模板题目集)
  • Window 2000 Perfectional_配置和管理FTP
  • uniapp内置蓝牙打印
  • Qt小组件 - 1(手风琴)
  • 计算机网络:【socket】【UDP】【地址转换函数】【TCP】
  • 测试第六讲-测试模型分类