人工智能直通车系列15【机器学习基础】(决策树算法原理(ID3、C4.5、CART)决策树模型实现)
目录
决策树算法原理
1. 决策树概述
2. ID3 算法原理
3. C4.5 算法原理
4. CART 算法原理
决策树模型实现(scikit - learn)
代码实现步骤
代码解释
场景示例
贷款审批
疾病诊断
决策树算法原理
1. 决策树概述
决策树是一种基本的分类与回归方法,它通过对特征进行递归划分,将数据集分割成不同的子集,直到每个子集都尽可能属于同一类别(分类树)或具有相似的数值输出(回归树)。决策树由节点和边组成,内部节点表示一个特征上的测试,分支表示测试输出,叶节点表示类别或值。
2. ID3 算法原理
- 核心思想:ID3 算法以信息增益作为特征选择的准则,倾向于选择信息增益大的特征进行划分。
- 信息增益:信息增益是基于信息熵的概念。信息熵是衡量数据不确定性的指标,对于一个分类问题,设数据集
中第
类样本所占的比例为
,则数据集
的信息熵为:
假设使用特征对数据集
进行划分,划分后得到
个子集
,每个子集
的样本数占总样本数的比例为
,则特征
对数据集
的信息增益为:
- 算法步骤:
- 计算数据集
的信息熵
。
- 对每个特征
,计算其对数据集
的信息增益
。
- 选择信息增益最大的特征作为当前节点的划分特征。
- 根据该特征的不同取值将数据集
划分为不同的子集,对每个子集递归地重复上述步骤,直到子集的样本都属于同一类别或没有可用的特征为止。
- 计算数据集
3. C4.5 算法原理
- 核心思想:C4.5 算法是对 ID3 算法的改进,它使用信息增益比作为特征选择的准则,克服了 ID3 算法倾向于选择取值较多的特征的缺点。
- 信息增益比:信息增益比是信息增益与特征
的固有值
的比值,特征
的固有值定义为:
特征对数据集
的信息增益比为:
- 算法步骤:与 ID3 算法类似,只是在选择划分特征时使用信息增益比代替信息增益。
4. CART 算法原理
- 核心思想:CART(Classification and Regression Trees)算法既可以用于分类问题,也可以用于回归问题。对于分类问题,使用基尼指数作为特征选择的准则;对于回归问题,使用均方误差作为划分准则。
- 基尼指数:对于一个分类问题,数据集 D 的基尼指数定义为:
假设使用特征及其取值
对数据集
进行划分,得到两个子集
和
,则特征
在取值
处的基尼指数为:
- 算法步骤:
- 分类树:
- 对于每个特征
及其每个可能的取值
,计算基尼指数
。
- 选择基尼指数最小的特征和取值作为当前节点的划分条件。
- 根据划分条件将数据集
划分为两个子集,对每个子集递归地重复上述步骤,直到满足停止条件(如子集样本数小于某个阈值)。
- 对于每个特征
- 回归树:选择使得划分后两个子集的均方误差之和最小的特征和划分点进行划分。
- 分类树:
决策树模型实现(scikit - learn)
代码实现步骤
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
# 1. 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 3. 创建决策树分类器
# 使用 CART 算法(默认)
clf = DecisionTreeClassifier(random_state=42)
# 4. 训练模型
clf.fit(X_train, y_train)
# 5. 进行预测
y_pred = clf.predict(X_test)
# 6. 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy}")
# 7. 可视化决策树
plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()
代码解释
- 数据加载:使用
load_iris
函数加载鸢尾花数据集,该数据集包含 4 个特征和 3 个类别。 - 数据划分:使用
train_test_split
函数将数据集划分为训练集和测试集,测试集占比 20%。 - 模型创建:创建
DecisionTreeClassifier
类的实例clf
,使用默认的 CART 算法。 - 模型训练:调用
fit
方法,使用训练集数据X_train
和对应的标签y_train
对模型进行训练。 - 模型预测:使用训练好的模型对测试集数据
X_test
进行预测,得到预测结果y_pred
。 - 模型评估:使用
accuracy_score
函数计算模型的准确率。 - 模型可视化:使用
plot_tree
函数将决策树可视化,展示决策树的结构和划分规则。
场景示例
贷款审批
银行在进行贷款审批时,可以收集客户的各种信息,如年龄、收入、信用记录等作为特征,客户是否违约作为目标变量。使用决策树模型,银行可以根据客户的信息判断客户是否有违约风险,从而决定是否批准贷款。
疾病诊断
在医疗领域,医生可以收集患者的症状、检查结果等信息作为特征,患者是否患有某种疾病作为目标变量。通过决策树模型,医生可以根据患者的信息进行初步的疾病诊断,辅助制定治疗方案。