当前位置: 首页 > news >正文

机器学习中决策树

一、简介

1.定义

        决策树是一种属性结构,其中:内部节点:代表对某个特征的判断(特征),分支:代表判断结果,叶子节点:代表最终分类结果(标签)。

2.决策树构建三步骤

        (1)特征选择:筛选对分类 / 回归贡献最大的特征

        (2)决策树生成:基于选定特征递归分裂数据集,生成初步树结构;

        (3)剪枝:缓解过拟合(初步树可能过度贴合训练数据,泛化能力差,需裁剪冗余分支)。

        注:决策树的核心优势是可解释性极强(类似 “if-else” 逻辑,易理解),无需特征标准化(如年龄和收入无需统一量级),但缺点是易过拟合,需依赖剪枝优化。

二、ID3决策树(离散特征专用)

1.核心指标:信息熵与信息增益

(1)信息熵(Entropy)

        定义:信息论中衡量数据不确定性的指标,熵越大,数据混乱度(不确定性)越高;熵越小,数据纯度越高。

        计算公式:对于数据集D,若包含k类样本,各类别占比为p1,p2...pk,则信息熵为:(对数底为 2,单位为 “比特”)

        案例验证:

                数据α(ABCDEFGH):8 类样本,每类占比(1/8),(H(α) = -8×(1/8)log_2(1/8) = 3);

                数据β(AAAABBCD):4 类样本(A:1/2,B:1/4,C:1/8,D:1/8),(H(β) = -(1/2log_2 1/2 + 1/4log_2 1/4 + 2×1/8log_2 1/8) ≈ 1.75);

                结论:(H(α) > H(β)),数据 α 更混乱。

        注:信息熵由香农(Shannon)在 1948 年《通信的数学理论》中提出,最初用于衡量通信中的 “信息不确定性”,后来被引入机器学习作为特征选择的核心指标。

(2)信息增益

        定义:特征a对数据集D的信息增益,等于 “数据集原熵H(D)” 减去 “特征a条件下的条件熵H(D|a)”,代表通过特征a分裂后,数据不确定性减少的程度。

        公式:g(D,a)=H(D)-H(D|a)

        (条件熵H(D|a):按特征α划分后的各子集熵的加权平均,权重为子集占总数据集的比例)

        案例计算(6 个样本:3A、3B,特征a分 α(4 样本:3A1B)、β(2 样本:2B)):

        原熵(H(D) = -(3/6log_2 3/6 + 3/6log_2 3/6) = 1);

        条件熵(H(D|a) = (4/6)×[-(3/4log_2 3/4 + 1/4log_2 1/4)] + (2/6)×[-(2/2log_2 2/2)] ≈ 0.54);

        信息增益g(D,a)=1-0.54=0.46。

2.ID3决策树构建流程

        (1)计算数据集中所有特征的信息增益;

        (2)选择信息增益最大的特征作为当前节点的分裂特征;

        (3)按该特征的取值将数据集拆分为子集;

        (4)对每个子集重复步骤 1-3,直到所有子集纯度达到阈值(如子集全为同一类别)或无特征可分裂。

3.典型案例:论坛客户流失分析

        数据:15 条样本(5 个流失正样本,10 个未流失负样本),特征为 “性别”“活跃度”;

        步骤:计算原熵→分别计算 “性别”“活跃度” 的信息增益→比较得出 “活跃度信息增益更大”,对流失的影响更显著。

        :ID3 的致命缺陷:①仅支持离散特征(无法处理年龄、收入等连续特征);②偏向选择取值数量多的特征(如 “用户 ID” 这类唯一值特征,信息增益极高,但无实际意义),后续 C4.5 决策树专门解决此问题。

三、C4.5决策树(ID3改进版)

1.ID3的痛点与C4.5的改进方向

        ID3痛点:偏向选择取值多的特征(如特征 b 有 6 个取值,特征 a 仅 2 个,ID3 易选 b,但 b 可能无实际预测价值);

        C4.5改进:用信息增益率替代信息增益,引入 “惩罚系数” 修正多取值特征的优势。

2.核心指标:信息增益率

        定义:信息增益率 = 信息增益 / 特征熵(特征熵:以特征a为随机变量的熵,衡量特征自身的不确定性);

        惩罚逻辑:若特征a取值多,其特征熵大,信息增益率会被 “稀释”,从而避免偏向多取值特征;

        案例验证(特征 a:2 取值,特征 b:6 取值):

                特征b的信息增益可能高,但特征熵更大(6 个唯一值,熵≈2.58),信息增益率低;

                特征a的信息增益率更高,最终被选为分裂特征。

3.C4.5的额外优势

        支持连续特征:通过 “离散化” 处理(如将年龄分为 [0-18,19-30,31+]);

        处理缺失值:用 “样本权重” 弥补(如某样本缺失 “年龄”,则按其他特征的分布分配权重);

        剪枝优化:自带后剪枝逻辑,降低过拟合风险。

        注:C4.5 的局限:①计算复杂度高(需多次计算熵和增益率);②无法处理超大数据集(需将数据全部载入内存,不支持分布式),工业界常用 CART 或随机森林(基于 CART)替代。

四、CART决策树(分类+回归双用途)

        CART(Classification and Regression Tree)是最常用的决策树模型,支持分类任务(用基尼指数)和回归任务(用平方损失),且无论特征类型,均生成二叉树(每个节点仅分 2 个分支)

1.CART分类树(预测离散类别)

        (1)核心指标:基尼指数

                定义:从数据集D中随机抽取 2 个样本,其类别标记不一致的概率,取值范围 [0,0.5],值越小,数据纯度越高;

                公式:

                基尼指数计算逻辑:选择使 “分裂后总基尼指数最小” 的特征和分裂点(如离散特征按 “是否为某取值” 分,连续特征按 “是否大于某阈值” 分)。

        (2)典型案例:是否拖欠贷款预测

                数据:10 条样本,特征为 “是否有房”“婚姻状况”“年收入”,目标为 “是否拖欠贷款”;

                计算关键:

                        [是否有房]:有房样本(3 个,全为 “不拖欠”,基尼 = 0),无房样本(7 个,4 不拖欠 3 拖欠,基尼≈0.49),总基尼指数 =(3/10)×0 +(7/10)×0.49≈0.343;

                        [婚姻状况]:按 “是否已婚” 分裂,已婚样本(4 个,全不拖欠,基尼 = 0),未婚样本(6 个,3 不拖欠 3 拖欠,基尼 = 0.5),总基尼指数 =(4/10)×0 +(6/10)×0.5=0.3;

                        [年收入]:按 97.5 为阈值分裂,总基尼指数 = 0.3;

                结论:选择 “婚姻状况(是否已婚)” 或 “年收入(97.5)” 作为分裂特征(基尼指数最小)。

2.CART回归树(预测连续值)

(1)与分类树的核心区别

维度CART 分类树CART 回归树
输出类型离散类别(如 “拖欠 / 不拖欠”)连续值(如 “房价”“收入”)
损失函数基尼指数平方损失(Loss=(f(x)-y)^2)
叶子节点输出子集中多数类别的标签子集中所有样本的均值

(2)CART回归树构建流程

        1.对特征x的取值排序,取相邻值的均值作为候选分裂点(如(x=[1,2,...,10]),候选点为 1.5,2.5,...,9.5);

        2.对每个候选点,将数据分为 “≤分裂点” 和 “>分裂点” 两个子集,计算两个子集的平方损失之和

        3.选择平方损失之和最小的分裂点作为当前特征的最优分裂点;

        4.对每个子集重复步骤 1-3,直到子集满足停止条件(如子集样本数≤阈值)。

(3)典型案例:特征x与目标y的回归

        数据:x=[1-10],(y=[5.56,5.7,5.91,6.4,6.8,7.05,8.9,8.7,9,9.05];

        步骤:

                1.候选分裂点为 1.5-9.5,计算各点平方损失,发现s=6.5时损失最小(m(s)=1.93);

                2.按s=6.5分裂为左子集(x≤6,6 个样本)和右子集(x>6,4 个样本);

                3.对左子集继续分裂,发现s=3.5时损失最小,最终生成二叉树。

        注:CART 的核心优势:①支持分类 + 回归双任务;②二叉树结构简洁,计算效率高;③可处理离散 / 连续特征,工业界应用最广(如随机森林、GBDT 等集成模型的基础均为 CART)。

五、泰塔尼克案例

1.案例背景

        数据:泰坦尼克号乘客数据(特征:Pclass(舱位等级)、Age(年龄)、Sex(性别)等,目标:Survived(是否生存));

        核心逻辑:通过决策树预测乘客是否能生存(历史事实:妇女、儿童、高舱位乘客生存率更高)。

2.实现步骤(Python+sklearn)

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, roc_curve
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
# 设置中文字体,解决中文显示问题
plt.rcParams["font.family"] = ["SimHei"]
# 解决负号显示问题
plt.rcParams['axes.unicode_minus'] = False#读取数据
data = pd.read_csv('../data/train.csv')
# data.info()
#数据预处理
# df = data.isnull().sum()
# print(df)
data['Sex'] = data['Sex'].map({'female': 0, 'male': 1})
data['Age'] = data['Age'].fillna(data['Age'].mean())  # 数值列填充
# data.fillna({'Age':data['Age'].mean()}, inplace=True)
data['Embarked'] = data['Embarked'].fillna(data['Embarked'].mode()[0])
data['Embarked'] = data['Embarked'].map({'S':0, 'C':1, 'Q':2})
data.drop('Cabin', axis=1, inplace=True)
data.info()X = data[['Pclass','Sex','Age', 'SibSp', 'Parch','Fare', 'Embarked']]
y = data['Survived']
# X = pd.get_dummies(X['Embarked'], prefix='Embarked')
# print(X.shape)
#特征处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=5999,stratify=y)
# transfer = StandardScaler()
# X_train = transfer.fit_transform(X_train)
# X_test = transfer.transform(X_test)
#创建模型
es = DecisionTreeClassifier()
#训练模型
es.fit(X_train,y_train)
#模型预测
y_predict = es.predict(X_test)
#模型评估
y_proba = es.predict_proba(X_test)[:,1]
print(f"准确值:{accuracy_score(y_test, y_predict)}")
print(f"召回率:{recall_score(y_test, y_predict)}")
print(f"精准率:{precision_score(y_test, y_predict)}")
print(f"f1分数:{f1_score(y_test, y_predict)}")
print(f"AUC指标:{roc_auc_score(y_test, y_proba)}")plt.figure(figsize=(8,4))
plot_tree(es,filled=True,max_depth=30)
plt.show()

3.关键结论

        重要特征:Sex(性别)>Pclass(舱位)>Age(年龄),符合 “妇女优先、高舱位优先” 的历史事实;

        调参建议:通过max_depth(树深)、min_samples_leaf(叶子节点最小样本数)限制树的复杂度,降低过拟合。

        注:实际项目中,决策树的调参核心是控制复杂度:①max_depth不宜过大(如超过 10 易过拟合);②min_samples_split(内部节点分裂最小样本数)设为 5-10,避免小样本分裂;③ccp_alpha(成本复杂度剪枝参数)可自动剪枝,sklearn 推荐使用。

六、决策树剪枝(防止过拟合的核心手段)

1.剪枝的必要性

        决策树若不剪枝,会过度贴合训练数据(如每个叶子节点仅 1 个样本),导致泛化能力差(测试集准确率低),剪枝通过 “删除冗余分支”,保留核心逻辑,提升泛化能力。

2.两种剪枝方法对比

剪枝类型核心逻辑优点缺点
预剪枝树生成过程中,对每个节点先判断:若分裂后验证集准确率无提升,则停止分裂,标记为叶子节点训练 / 测试效率高,避免冗余分支生成可能欠拟合(过早停止分裂,错过后续有效分支)
后剪枝先生成完整决策树,再自底向上遍历非叶节点:若将子树替换为叶子节点后准确率提升,则剪枝泛化能力强,欠拟合风险低训练效率低(需生成完整树再剪枝)

3.案例:西瓜数据集剪枝

        预剪枝:生成 “脐部 =?” 节点后,若 “色泽 =?” 分裂后验证集准确率从 71.4% 降至 57.1%,则禁止分裂,保留 “脐部” 节点;

        后剪枝:先生成含 6 个内部节点的完整树,再判断 “色泽 =?”“纹理 =?” 等节点,若剪枝后准确率提升(如从 57.1%→71.4%),则剪枝。

        注:工业界常用后剪枝 + 预剪枝结合:①先用预剪枝快速生成基础树;②再用后剪枝优化局部分支;③sklearn 中DecisionTreeClassifierccp_alpha参数实现 “成本复杂度剪枝”(后剪枝的一种),通过最小化 “训练损失 +α× 树复杂度” 选择最优树。

七、三大分类决策树对比

决策树类型提出时间分支方式支持特征类型核心优势核心局限
ID31975信息增益仅离散计算简单,易理解不支持连续 / 缺失值,偏多取值特征
C4.51993信息增益率离散 + 连续解决 ID3 偏倚,支持缺失值计算复杂,不支持大数据集
CART1984基尼指数(分类)离散 + 连续支持分类 / 回归,二叉树高效对异常值敏感
平方损失(回归)

八、核心知识点梳理

1.决策树定义:内部节点是特征判断,叶子是分类结果,可明确特征重要性;

2.信息熵:熵越大,混乱度越高

3.信息熵增率:缓解多取值特征偏倚,C4.5 核心

4.基尼指数:CART 分类树核心,值越小纯度越高

5.剪枝方法:预剪枝(生成时剪)、后剪枝(生成后剪),均为防过拟合


文章转载自:

http://LyvFdPKy.rxfjg.cn
http://jmaXrUmw.rxfjg.cn
http://APBGh9ca.rxfjg.cn
http://EZHt5Xk1.rxfjg.cn
http://Tg4wBJ1k.rxfjg.cn
http://UdZS3m6n.rxfjg.cn
http://5KQXc6C5.rxfjg.cn
http://HcDQQynr.rxfjg.cn
http://TutbxeqY.rxfjg.cn
http://4uZ79arG.rxfjg.cn
http://Wel4MxVn.rxfjg.cn
http://TX6e4NF2.rxfjg.cn
http://xgAhbDhH.rxfjg.cn
http://3Q1aOLSI.rxfjg.cn
http://Fj9hRHGP.rxfjg.cn
http://kpUanIBl.rxfjg.cn
http://z4PC5e8v.rxfjg.cn
http://bv8Nwgn6.rxfjg.cn
http://kxTIfOLG.rxfjg.cn
http://veRZABfS.rxfjg.cn
http://htyS08HM.rxfjg.cn
http://2PoeFp0j.rxfjg.cn
http://wddyxMNx.rxfjg.cn
http://4io43Sr2.rxfjg.cn
http://4YbwJkIq.rxfjg.cn
http://7NTM8BPd.rxfjg.cn
http://n2OG0QeT.rxfjg.cn
http://K922lP73.rxfjg.cn
http://LoW7d8Kr.rxfjg.cn
http://juXMey0Y.rxfjg.cn
http://www.dtcms.com/a/367296.html

相关文章:

  • 算法 --- 分治(归并)
  • 深入探索 WebSocket:构建实时应用的核心技术
  • javaweb(AI)-----前端
  • C++11 类功能与包装器
  • Qt---connect建立对象间的通信链路
  • vLLM显存逆向计算:如何得到最优gpu-memory-utilization参数
  • 第15章 Jenkins最佳实践
  • 【倒计时2个月】好•真题资源+专业•练习平台=高效备赛2025初中古诗文大会
  • openEuler2403安装部署Kafbat
  • matlab 数据分析教程
  • git还原操作
  • Spring Cloud OpenFeign 核心原理
  • 【华为培训笔记】OptiX OSN 9600 设备保护专题
  • 解决 ES 模块与 CommonJS 模块互操作性的关键开关esModuleInterop
  • 解密llama.cpp:Prompt Processing如何实现高效推理?
  • 抽象与接口——Java的“武器模板”与“装备词条”
  • 数组本身的深入解析
  • Linux Centos7搭建LDAP服务(解决设置密码生成密文添加到配置文件配置后输入密码验证报错)
  • 记录一下tab梯形圆角的开发解决方案
  • java面试中经常会问到的dubbo问题有哪些(基础版)
  • illustrator-04
  • 观察者模式-红绿灯案例
  • 【LLM】FastMCP v2 :让模型交互更智能
  • Linux下开源邮件系统Postfix+Extmail+Extman环境部署记录
  • 在Anaconda下安装GPU版本的Pytorch的超详细步骤
  • 追觅科技举办2025「敢梦敢为」发布会,发布超30款全场景重磅新品
  • 从“AI炼金术”到“研发加速器”:一个研发团队的趟坑与重生实录
  • B站 XMCVE Pwn入门课程学习笔记(9)
  • 【数学建模学习笔记】机器学习回归:XGBoost回归
  • 本地部署开源数据生成器项目实战指南