当前位置: 首页 > news >正文

决策树的学习(二)

一、整体框架

本 PPT 聚焦机器学习中的决策树算法,围绕 “核心算法(ID3、C4.5、CART)→ 特殊问题(连续值处理)→ 优化策略(剪枝)→ 代码实现→ 课堂练习” 展开,系统补充决策树的进阶知识,解决基础算法的局限性并提供工程化落地思路。

二、核心决策树算法

(一)ID3 算法

  1. 核心衡量标准:信息增益

定义:某属性划分数据后,带来的 “熵减少量”(即纯度提升程度)。

选择逻辑:信息增益越大,该属性划分后数据纯度越高,优先作为划分属性。

  1. 致命局限性:对 “可取值数目较多的属性” 存在偏好。例如数据中的 “ID” 属性(每个 ID 唯一),用其划分会使每个子集仅含 1 个样本,信息增益最大,但无实际预测意义,易导致过拟合。
  1. 示例数据集:7 条 “是否出去玩” 的记录,特征包括天气(晴 / 阴 / 雨)、温度(高 / 适中 / 低)、湿度(高 / 正常)、是否多云(是 / 否),用于演示属性选择逻辑。

(二)C4.5 算法(解决 ID3 属性偏好问题)

  1. 核心改进:用 “信息增益率” 替代信息增益

公式:信息增益率 = 信息增益 ÷ 该属性自身的熵。

原理:属性自身熵会随 “可取值数目” 增加而增大(如 “ID” 属性自身熵极高),通过除法抵消高取值数属性的优势,避免 ID3 的偏好问题。

  1. 适用场景:需平衡 “属性区分度” 与 “取值数影响” 的分类任务,同样基于上述 7 条 “是否出去玩” 数据集演示计算逻辑,修正 ID3 的局限性。

(三)CART 算法(分类与回归通用)

  1. 核心衡量标准:基尼指数(Gini (D))

定义:反映从数据集 D 中随机抽取 2 个样本,类别标记不一致的概率,公式为 \(Gini(D) = 1 - \sum_{k=1}^{n} p_k^2\)(\(p_k\)为第 k 类样本在 D 中的占比)。

规律:\(p_k\)越大(数据纯度越高),Gini (D) 越小;当所有样本属于同一类时,Gini (D)=0(纯度最高)。

  1. 特点:既支持分类任务(用基尼指数衡量纯度),也支持回归任务(用平方误差衡量损失),是工程中常用的决策树算法。

三、特殊问题:连续值处理

当特征为连续值(如收入、年龄)时,无法直接按离散值划分,需通过 “离散化” 转化,核心方法为贪婪算法,步骤如下:

  1. 排序:将连续特征的所有取值按升序排列。例如 “应税收入(Taxable Income)” 样本值排序为:60K、70K、75K、85K、90K、95K、100K、120K、125K、220K。
  1. 确定分界点:若对连续值做 “二分划分”,则分界点数量 = 取值个数 - 1(如 10 个取值对应 9 个分界点,取相邻两个值的中间值,如 65K、72.5K 等)。
  1. 选择最优分界点:遍历所有可能的分界点,用信息增益(ID3/C4.5)或基尼指数(CART)计算划分效果,选择最优分界点(如示例中 “TaxIn<=80” 或 “TaxIn<=97.5”),完成连续值到离散值的转化。

四、决策树优化:剪枝策略(解决过拟合)

(一)剪枝的必要性

决策树理论上可通过不断划分,将训练数据 “完全分开”,但会导致模型过度拟合训练数据(对噪声敏感,泛化能力差),因此需通过剪枝降低复杂度。

(二)预剪枝

  1. 定义:“边构建决策树边剪枝”,在树的生长过程中提前停止分支,是工程中更实用的策略。
  1. 剪枝依据(停止条件)

限制树的最大深度(如最多 5 层);

限制叶子节点最少样本数(如叶子节点需含≥10 个样本才继续分支);

限制信息增益 / 基尼指数阈值(如划分后增益低于 0.1 则停止)。

  1. 优势:计算成本低,可避免构建复杂的全量树;劣势:可能因 “提前停止” 导致欠拟合(未充分学习数据规律)。

(三)后剪枝

  1. 定义:“先构建完整决策树,再从叶子节点向根节点回溯剪枝”,保留更贴合数据规律的分支。
  1. 剪枝衡量标准:损失函数

公式:最终损失 = 树自身的基尼系数(或熵) + α× 叶子节点数量(α 为正则化系数)。

α 的影响:

α 越大:正则化越强,优先减少叶子节点数量(树越简单),虽降低过拟合风险,但可能导致模型精度下降;

α 越小:更关注模型精度,叶子节点数量多,过拟合风险较高。

  1. 示例验证:以 “好瓜 / 坏瓜” 分类为例,通过验证集精度判断是否剪枝:

原分支 “色泽 =?”:剪枝前精度 57.1%,剪枝后 71.4%→决策剪枝;

原分支 “纹理 =?”:剪枝前精度 42.9%,剪枝后 57.1%→决策剪枝。

  1. 优势:泛化能力更强,不易欠拟合;劣势:需构建完整树,计算成本高于预剪枝。

五、决策树代码实现(基于 Python)

核心调用sklearn库的DecisionTreeClassifier()类,关键参数及含义如下表:

参数名

取值范围

核心作用

criterion

gini(基尼指数)、entropy(信息熵)

定义属性划分的衡量标准(CART 用 gini,ID3/C4.5 思路用 entropy)

splitter

best(所有特征找最优切分点)、random(部分特征找切分点)

控制切分点选择的随机性,random 可降低过拟合风险(类似随机森林思路)

max_features

None(用所有特征)、log2(log₂(特征数))、sqrt(根号下特征数)、整数 N

限制每次划分时考虑的特征数量,避免冗余特征影响,提升效率

max_depth

整数(如 5-20)或 None

限制树的最大深度,深度越大越易过拟合,推荐 5-20 之间平衡精度与泛化能力

六、课堂练习

任务:使用决策树算法对 “泰坦尼克号幸存者” 数据集进行预测,核心目标是:

  1. 实践特征选择(如乘客年龄、性别、舱位等级等);
  1. 调整DecisionTreeClassifier()参数(如criterion、max_depth);
  1. 结合剪枝思路优化模型,提升预测准确率(如避免过拟合)。代码如下:
    # 泰坦尼克号幸存者预测 - 决策树完整流程# 1. 导入必要的库
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
    from sklearn.metrics import accuracy_score# 2. 数据加载与初步查看
    # 使用seaborn内置数据集(无需本地文件)
    data = sns.load_dataset('titanic')
    # 查看数据基本信息
    data.info()
    # 查看前5行数据
    data.head()# 3. 数据预处理
    # 删除缺失值过多的列和无关特征
    data.drop(["Cabin", "Name", "Ticket"], inplace=True, axis=1)
    # 处理缺失值:年龄用均值填充
    data["Age"] = data["Age"].fillna(data["Age"].mean())
    # 删除剩余含有缺失值的行
    data = data.dropna()# 4. 分类变量编码
    # 将Embarked三分类变量转换为数值型
    labels = data["Embarked"].unique().tolist()
    data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))
    # 将Sex二分类变量转换为0和1(男=1,女=0)
    data["Sex"] = (data["Sex"] == "male").astype("int")# 5. 特征与标签分离
    # 特征数据(排除Survived列)
    X = data.iloc[:, data.columns != "Survived"]
    # 标签数据(仅Survived列)
    y = data.iloc[:, data.columns == "Survived"]# 6. 划分训练集和测试集
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=42)# 修正索引(确保索引连续)
    for i in [Xtrain, Xtest, Ytrain, Ytest]:i.index = range(i.shape[0])# 7. 初步模型训练与评估
    # 创建决策树模型
    clf = DecisionTreeClassifier(random_state=25)
    # 训练模型
    clf.fit(Xtrain, Ytrain)
    # 测试集评估
    score = clf.score(Xtest, Ytest)
    print(f"测试集准确率: {score:.4f}")# 8. 交叉验证评估
    cv_score = cross_val_score(clf, X, y, cv=10).mean()
    print(f"10折交叉验证平均准确率: {cv_score:.4f}")# 9. 决策树深度调优与可视化
    # 存储不同深度下的得分
    tr_scores = []
    te_scores = []# 测试深度1到10的决策树
    for i in range(10):clf = DecisionTreeClassifier(random_state=20,max_depth=i + 1,criterion="entropy")clf.fit(Xtrain, Ytrain)# 训练集得分tr_scores.append(clf.score(Xtrain, Ytrain))# 交叉验证得分te_scores.append(cross_val_score(clf, X, y, cv=10).mean())# 输出最佳得分
    print(f"最佳训练集得分: {max(tr_scores):.4f}")
    print(f"最佳交叉验证得分: {max(te_scores):.4f}")# 可视化得分随深度变化的趋势
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, 11), tr_scores, color="red", label="训练集得分")
    plt.plot(range(1, 11), te_scores, color="blue", label="交叉验证得分")
    plt.xticks(range(1, 11))
    plt.xlabel("决策树深度")
    plt.ylabel("准确率")
    plt.title("不同深度下的模型性能")
    plt.legend()
    plt.show()# 10. 网格搜索超参数优化
    # 生成参数候选值
    gini_thresholds = np.linspace(0, 0.5, 20)# 定义超参数网格
    parameters = {'splitter': ('best', 'random'),'criterion': ("gini", "entropy"),'max_depth': [*range(1, 10)],'min_samples_leaf': [*range(1, 50, 5)],'min_impurity_decrease': [*np.linspace(0, 0.5, 20)]
    }# 创建网格搜索对象
    GS = GridSearchCV(estimator=DecisionTreeClassifier(random_state=25),param_grid=parameters,cv=10
    )# 执行网格搜索
    GS.fit(Xtrain, Ytrain)# 输出最优参数和最佳得分
    print("最优超参数组合:", GS.best_params_)
    print("最优参数下的交叉验证得分:", GS.best_score_)# 11. 使用最优模型进行预测
    best_clf = GS.best_estimator_
    y_pred = best_clf.predict(Xtest)
    print(f"最优模型在测试集上的准确率: {accuracy_score(Ytest, y_pred):.4f}")
    

http://www.dtcms.com/a/338759.html

相关文章:

  • MCP(模型上下文协议):是否是 AI 基础设施中缺失的标准?
  • jsPDF 不同屏幕尺寸 生成的pdf不一致,怎么解决
  • Ansible 中的文件包含与导入机制
  • java17学习笔记-Deprecate the Applet API for Removal
  • C语言基础:(十八)C语言内存函数
  • 连接远程服务器上的 jupyter notebook,解放本地电脑
  • 计算机毕设推荐:痴呆症预测可视化系统Hadoop+Spark+Vue技术栈详解
  • 生成式AI的能力边界与职业重构:从“百科实习生“到人机协作增强器
  • 人工智能学派简介
  • 当宠物机器人装上「第六感」:Deepoc 具身智能如何重构宠物机器人照看逻辑
  • Python字符串变量插值深度解析:从基础到高级工程实践
  • 安装DDNS-go
  • 【部署相关】DockerKuberbetes常用命令大全(速查+解释)
  • 便携式科研土壤监测仪:让土壤检测走进 “轻时代”
  • 大数据MapReduce架构:分布式计算的经典范式
  • 【MySQL】--- 库表操作
  • Python + 淘宝 API 开发:自动化采集商品数据的完整流程​
  • Redis(11)如何通过命令行操作Redis?
  • 对象创建过程
  • 《算法导论》第 32 章 - 字符串匹配
  • 大数据云原生是什么
  • 中国技术引领人工心脏变革——欧洲心脏与心力衰竭大会特别报道
  • 思科语音系统简要了解
  • 【科研绘图系列】R语言绘制多种小提琴和云雨图
  • 期权小故事:王安石变法与期权
  • electron进程间通信- 渲染进程与主进程双向通信
  • GitHub 热榜项目 - 日榜(2025-08-19)
  • 从现场到云端的“通用语”:Kepware 在工业互联中的角色、使用方法与本土厂商(以胡工科技为例)的差异与优势
  • AiPPT怎么样?好用吗?
  • Ubuntu22系统上源码部署LLamaFactory+微调模型 教程【亲测成功】