当前位置: 首页 > news >正文

随机森林算法详解:Bagging思想的代表算法

文章目录

    • 一、随机森林算法介绍
      • 1.1 算法核心思想
      • 1.2 算法优势
      • 1.3 理论基础
    • 二、随机森林算法计算流程示例
      • 2.1 示例数据集
      • 2.2 决策树构建过程
      • 2.3 新样本预测流程
    • 三、sklearn中随机森林算法的API与实现原理
      • 3.1 核心API详解(sklearn.ensemble.RandomForestClassifier)
      • 3.2 底层实现原理
      • 3.3 特征重要性评估
    • 四、代码实现示例
      • 4.1 完整代码实现
      • 4.2 代码执行解析
      • 4.3运行结果
    • 五、总结与扩展


一、随机森林算法介绍

随机森林是集成学习领域中基于Bagging(Bootstrap Aggregating)思想的经典算法,其核心在于通过构建多个决策树弱学习器,利用“群体智慧”提升模型的泛化能力和鲁棒性。与传统单一决策树相比,随机森林通过双重随机化机制(样本抽样随机化和特征选择随机化)有效降低了模型方差,避免过拟合问题。

集成学习介绍:网页链接

1.1 算法核心思想

  • Bagging集成框架:通过有放回抽样(Bootstrap)生成多个不同的训练子集,每个子集训练一棵决策树,最终通过投票机制整合结果
  • 双重随机化
    • 样本层面:每次从原始数据集有放回抽取约63.2%的样本构成新训练集
    • 特征层面:每个节点分裂时仅从随机选择的k个特征中选择最优分裂特征
  • 决策树基学习器:默认使用CART(分类与回归树)作为弱学习器,不进行剪枝操作

1.2 算法优势

  • 抗噪声能力强:通过多棵树投票抵消单一树的预测偏差
  • 处理高维数据高效:自动筛选重要特征,无需复杂特征工程
  • 可解释性:可通过特征重要性评估理解各特征对预测的贡献
  • 并行计算友好:各决策树可独立训练,适合大规模数据处理

1.3 理论基础

  • 方差-偏差分解:通过增加模型多样性降低方差,保持偏差稳定
  • 大数定律:随着树的数量增加,集成模型的预测精度趋近于理论最优值
  • 多样性-准确性权衡:通过控制抽样和特征选择的随机性平衡模型多样性与一致性

二、随机森林算法计算流程示例

2.1 示例数据集

假设我们有一个二分类问题,包含5个样本,每个样本有3个特征,数据集如下:

样本编号特征1特征2特征3类别
11230
24561
37890
42341
55670

待预测新样本为:[3, 4, 5],我们将构建3棵决策树(n_estimators=3),每棵树随机选取2个特征(max_features=2)。

决策树介绍:网页链接

2.2 决策树构建过程

决策树1构建

  1. 样本抽样:有放回抽取5个样本,结果为[1, 2, 2, 4, 5](样本2被抽取两次,样本3未被抽取)
  2. 特征选择:随机选取特征1和特征2
  3. 训练过程
    • 计算特征1和特征2的分裂增益,假设最优分裂点为特征1=3
    • 分裂规则:若特征1 < 3 分类为0;否则分类为1
  4. 树结构
    在这里插入图片描述

决策树2构建

  1. 样本抽样:有放回抽取5个样本,结果为[1, 3, 3, 4, 5](样本3被抽取两次,样本2未被抽取)
  2. 特征选择:随机选取特征2和特征3
  3. 训练过程
    • 计算特征2和特征3的分裂增益,最优分裂点为特征2=5
    • 分裂规则:若特征2 < 5 分类为0;否则分类为1
  4. 树结构
    在这里插入图片描述

决策树3构建

  1. 样本抽样:有放回抽取5个样本,结果为[2, 2, 3, 4, 5](样本2被抽取两次)
  2. 特征选择:随机选取特征1和特征3
  3. 训练过程
    • 计算特征1和特征3的分裂增益,最优分裂点为特征3=6
    • 分裂规则:若特征3 < 6 分类为0;否则分类为1
  4. 树结构
    在这里插入图片描述

2.3 新样本预测流程

  1. 决策树1预测
    • 特征1=3,不满足特征1 < 3 预测为1
  2. 决策树2预测
    • 特征2=4 < 5 预测为0
  3. 决策树3预测
    • 特征3=5 < 6 预测为0
  4. 投票结果:1票支持类别1,2票支持类别0 最终预测为0

三、sklearn中随机森林算法的API与实现原理

3.1 核心API详解(sklearn.ensemble.RandomForestClassifier)

class sklearn.ensemble.RandomForestClassifier(n_estimators=100, *, criterion='gini', max_depth=None,max_features='auto',bootstrap=True,random_state=None,...
)

关键参数解析:

  • n_estimators:决策树数量,默认100。增加数量可提高精度,但会增加计算开销
  • criterion:节点分裂准则,可选’gini’(基尼系数)或’entropy’(信息熵),默认’gini’
    • 基尼系数: G i n i ( p ) = 1 − ∑ k = 1 K p k 2 Gini(p) = 1 - \sum_{k=1}^K p_k^2 Gini(p)=1k=1Kpk2,值越小表示节点越纯净
    • 信息熵: E n t r o p y ( p ) = − ∑ k = 1 K p k log ⁡ 2 p k Entropy(p) = -\sum_{k=1}^K p_k \log_2 p_k Entropy(p)=k=1Kpklog2pk,值越小表示不确定性越低
  • max_depth:决策树最大深度,默认None(完全生长)。限制深度可防止过拟合
  • max_features:每个节点分裂时考虑的最大特征数,取值规则:
    • ‘auto’/‘sqrt’: n _ f e a t u r e s \sqrt{n\_features} n_features
    • ‘log2’: log ⁡ 2 ( n _ f e a t u r e s ) \log_2(n\_features) log2(n_features)
    • None:使用全部特征
  • bootstrap:是否使用有放回抽样,默认True。若False则使用全部样本

3.2 底层实现原理

  1. 样本抽样机制

    • Bagging抽样:对于n个样本的原始数据集,每次有放回抽取n个样本,约36.8%的样本不会被抽到(称为袋外数据,OOB)
  2. 特征随机选择

    • 子空间抽样:每个节点分裂时从随机选择的k个特征中寻找最优分裂特征,k通常为 d \sqrt{d} d (d为总特征数)
    • 特征重要性计算:基于Gini不纯度减少或排列重要性评估各特征对预测的贡献
  3. 并行训练实现

    • 多线程并行:利用joblib实现决策树的并行训练,默认使用全部可用CPU核心
    • 底层优化:使用Cython优化树构建过程,支持大规模数据高效训练
  4. 预测集成机制

    • 分类任务:硬投票(多数表决),少数服从多数
    • 回归任务:简单平均各树的预测结果

3.3 特征重要性评估

# 训练模型后获取特征重要性
rfc = RandomForestClassifier()
rfc.fit(X, y)
print(rfc.feature_importances_)
  • 基于基尼不纯度减少计算,反映特征对节点分裂的贡献程度
  • 取值范围[0,1],总和为1

四、代码实现示例

4.1 完整代码实现

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree# =============================================
# 1. 定义数据集(5个样本,3个特征)
# =============================================
X = np.array([[1, 2, 3],  # 样本1[4, 5, 6],  # 样本2[7, 8, 9],  # 样本3[2, 3, 4],  # 样本4[5, 6, 7]   # 样本5
])y = np.array([0, 1, 0, 1, 0])  # 对应类别标签# 新样本用于预测
new_sample = np.array([[3, 4, 5]])# =============================================
# 2. 构建随机森林模型
# =============================================
rfc = RandomForestClassifier(n_estimators=3,     # 使用3棵决策树max_features=2,     # 每次分裂时随机选择2个特征random_state=44,    # 固定随机种子以保证结果可复现max_depth=1         # 设置最大深度为1,与手动构造一致
)
rfc.fit(X, y)# =============================================
# 3. 预测新样本
# =============================================
prediction = rfc.predict(new_sample)
print("新样本[3,4,5]的预测类别为:", prediction[0])# =============================================
# 4. 可视化每棵树的结构
# =============================================
estimators = rfc.estimators_  # 获取所有决策树for i, tree in enumerate(estimators):plt.figure(figsize=(8, 4))plt.title(f"随机森林中的第 {i+1} 棵决策树", fontsize=14)plot_tree(tree,feature_names=[f'Feature{j+1}' for j in range(X.shape[1])],class_names=['Class 0', 'Class 1'],filled=True,fontsize=10,rounded=True)plt.tight_layout()plt.show()

4.2 代码执行解析

  1. 数据准备:定义5个样本的特征矩阵X和类别向量y
  2. 模型初始化:设置3棵决策树,每棵树考虑2个特征
  3. 模型训练:底层自动完成3棵树的并行构建
    • 每棵树独立进行样本抽样和特征选择
    • 每棵树使用CART算法构建决策树
  4. 预测过程
    • 新样本分别输入3棵树
    • 收集各树预测结果并投票
  5. 结果输出
    • 打印最终预测类别
    • 输出各特征重要性评分

4.3运行结果

新样本[3,4,5]的预测类别为: 1分类报告(训练集):precision    recall  f1-score   support0       0.67      0.67      0.67         31       0.50      0.50      0.50         2accuracy                           0.60         5macro avg       0.58      0.58      0.58         5
weighted avg       0.60      0.60      0.60         5

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

五、总结与扩展

随机森林作为Bagging思想的典型代表,通过“分散决策+集体智慧”的策略有效提升了模型性能,在工业界和学术界均有广泛应用。其核心优势在于:

  1. 抗干扰能力:通过多树投票降低单一模型的预测偏差
  2. 自适应特性:自动处理特征交互,无需复杂特征工程
  3. 可扩展性:支持大规模数据并行训练,适合分布式计算

对于实际应用,建议:

  • 分类问题优先使用随机森林作为基线模型
  • 通过网格搜索优化n_estimatorsmax_features参数
  • 利用特征重要性进行特征筛选和解释
  • 对于极高维数据可考虑结合PCA降维

随机森林的变种如Extra Trees(极端随机树)通过进一步随机化分裂规则,在牺牲少许偏差的前提下大幅降低方差,可作为进阶选择。

http://www.dtcms.com/a/268402.html

相关文章:

  • 自存bro code java course 笔记(2025 及 2020)
  • 【Linux网络编程】Socket - UDP
  • CppCon 2018 学习:What do you mean “thread-safe“
  • Linux操作系统之文件(五):文件系统(下)
  • 数据库|达梦DM数据库安装步骤
  • 谷歌浏览器安全输入控件-allWebSafeInput控件
  • 黑布淡入淡出效果
  • Vue2 day07
  • STM32两种不同的链接配置方式
  • Python 中 ffmpeg-python 库的详细使用
  • CppCon 2018 学习:Undefined Behavior is Not an Error
  • Solidity——pure 不消耗gas的情况、call和sendTransaction区别
  • 【PyTorch】PyTorch中torch.nn模块的池化层
  • 汇编与接口技术:8259中断实验
  • Dify+Ollama+QwQ:3步本地部署,开启AI搜索新篇章
  • 1025 反转链表(附详细注释,逻辑分析)
  • 网络调式常用知识
  • 【机器学习笔记Ⅰ】1 机器学习
  • 【拓扑空间】可分性2
  • Spring Boot 集成 Thymeleaf​​ 的快速实现示例,无法渲染页面问题解决
  • 记录一点开发技巧
  • Spring Boot 3.x 整合 Swagger(springdoc-openapi)实现接口文档
  • class类和style内联样式的绑定 + 事件处理 + uniapp创建自定义页面模板
  • React Ref 指南:原理、实现与实践
  • 深度学习篇---Yolov系列
  • 远程桌面启动工具
  • Flutter 每日翻译之 Widget
  • Day53GAN对抗生成网络思想
  • MySQL主从复制与读写分离概述
  • 一文了解PMI、CSPM、软考、、IPMA、PeopleCert和华为项目管理认证