当前位置: 首页 > news >正文

决策树原理详解

机器学习经典算法:决策树原理详解

在机器学习领域,决策树是一种简单而强大的分类和回归算法。它通过构建树形模型来进行决策,具有易于理解和解释的特点。本文将详细介绍决策树的原理,并使用C++语言实现一个简单的决策树模型。

一、决策树的基本原理

决策树的核心思想是通过递归地将数据集划分为不同的子集,直到每个子集中的数据足够“纯净”或达到了预设的停止条件。其主要步骤包括特征选择、数据集划分以及递归构建树。

(一)特征选择

特征选择是决策树算法的关键步骤之一。每次划分数据时,算法需要选择一个最佳特征和该特征的一个阈值(对于连续特征)来划分数据。特征的选择标准通常是基于信息增益基尼指数来衡量每个特征在区分数据上的效果。

1. 计算熵 (Entropy)

熵用于衡量数据集的不确定性或纯度。熵越低,数据集越纯。其公式为:

[ H(S) = -\sum_{i=1}^{c} p_i \log_2(p_i) ]

其中,( p_i )是类别 ( i ) 在数据集 ( S ) 中的比例,( c ) 是类别的总数。

2. 计算信息增益 (Information Gain)

信息增益是通过选择一个特征进行数据划分所带来的熵的减少量。信息增益越高,表示特征的分类能力越强。其公式为:

[ IG(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v) ]

其中,( S ) 是数据集,( A ) 是待选择的特征,( S_v ) 是根据特征 ( A ) 的值 ( v ) 划分后的子集。

(二)数据集划分

在选择了最优特征后,算法需要根据该特征的取值将数据集分为两个子集。对于分类问题,这通常涉及将数据按照某个特征值的阈值分成“左子集”和“右子集”。对于连续特征,决策树会寻找一个最佳的阈值来分割数据;对于离散特征,则可以直接根据特征的不同取值进行分割。

(三)递归构建决策树

决策树是一个递归结构。当数据集被划分后,每个子集会被进一步划分,直到满足以下条件之一:

  • 数据集纯度高:如果一个节点中的数据样本全部属于同一个类别,停止划分。
  • 特征用尽:如果没有剩余的特征可以用于划分,停止划分。
  • 达到预设的停止条件:例如达到最大树深度或最小样本节点数量。

二、C++实现决策树

接下来,我们将使用C++语言实现一个简单的决策树模型。为了简化实现,我们将使用ID3算法,基于信息增益进行特征选择。

(一)数据结构定义

首先,定义决策树的节点结构:

#include <iostream>
#include <vector>
#include <cmath>
#include <map>
#include <algorithm>
using namespace std;

struct TreeNode {
    int feature_index = -1; // 特征索引
    double threshold = 0;   // 阈值
    int label = -1;         // 叶节点标签
    TreeNode* left = nullptr; // 左子树
    TreeNode* right = nullptr; // 右子树
};

(二)熵和信息增益的计算

实现熵和信息增益的计算函数:

double entropy(const vector<int>& labels) {
    map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }
    double total = labels.size();
    double ent = 0.0;
    for (const auto& pair : label_counts) {
        double p = pair.second / total;
        ent -= p * log2(p);
    }
    return ent;
}

double informationGain(const vector<int>& labels, const vector<double>& feature_values, double threshold) {
    vector<int> left_labels, right_labels;
    for (size_t i = 0; i < labels.size(); ++i) {
        if (feature_values[i] <= threshold) {
            left_labels.push_back(labels[i]);
        } else {
            right_labels.push_back(labels[i]);
        }
    }
    double total = labels.size();
    double p_left = left_labels.size() / total;
    double p_right = right_labels.size() / total;
    return entropy(labels) - (p_left * entropy(left_labels) + p_right * entropy(right_labels));
}

(三)构建决策树

实现递归构建决策树的函数:

TreeNode* buildTree(const vector<vector<double>>& data, const vector<int>& labels, vector<int> used_features, int max_depth) {
    if (max_depth == 0 || all_of(labels.begin(), labels.end(), [&](int label) { return label == labels[0]; })) {
        return new TreeNode{ -1, 0, labels[0], nullptr, nullptr };
    }

    int best_feature = -1;
    double best_threshold = 0;
    double best_gain = -1;
    for (size_t feature = 0; feature < data[0].size(); ++feature) {
        if (find(used_features.begin(), used_features.end(), feature) != used_features.end()) {
            continue;
        }
        vector<double> feature_values;
        for (const auto& row : data) {
            feature_values.push_back(row[feature]);
        }
        sort(feature_values.begin(), feature_values.end());
        for (size_t i = 0; i < feature_values.size() - 1; ++i) {
            double threshold = (feature_values[i] + feature_values[i + 1]) / 2;
            double gain = informationGain(labels, feature_values, threshold);
            if (gain > best_gain) {
                best_gain = gain;
                best_feature = feature;
                best_threshold = threshold;
            }
        }
    }

    if (best_feature == -1) {
        return new TreeNode{ -1, 0, labels[0], nullptr, nullptr };
    }

    used_features.push_back(best_feature);
    vector<vector<double>> left_data, right_data;
    vector<int> left_labels, right_labels;
    for (size_t i = 0; i < data.size(); ++i) {
        if (data[i][best_feature] <= best_threshold) {
            left_data.push_back(data[i]);
            left_labels.push_back(labels[i]);
        } else {
            right_data.push_back(data[i]);
            right_labels.push_back(labels[i]);
        }
    }

    TreeNode* node = new TreeNode{ best_feature, best_threshold, -1, nullptr, nullptr };
    node->left = buildTree(left_data, left_labels, used_features, max_depth - 1);
    node->right = buildTree(right_data, right_labels, used_features, max_depth - 1);
    return node;
}

(四)预测函数

实现使用训练好的决策树对新样本进行预测的函数:

int predict(const vector<double>& sample, TreeNode* tree) {
    if (tree->label != -1) {
        return tree->label;
    }
    if (sample[tree->feature_index] <= tree->threshold) {
        return predict(sample, tree->left);
    } else {
        return predict(sample, tree->right);
    }
}

(五)测试决策树

使用简单的数据集测试决策树模型:

int main() {
    vector<vector<double>> data = {
        {0, 0},
        {0, 1},
        {1, 0},
        {1, 1}
    };
    vector<int> labels = {0, 1, 1, 0};

    TreeNode* tree = buildTree(data, labels, {}, 3);

    vector<double> test_sample = {1, 0};
    int prediction = predict(test_sample, tree);
    cout << "Prediction for sample [1, 0]: " << prediction << endl;

    return 0;
}

三、总结

本文详细介绍了决策树的基本原理,并使用C++语言实现了一个简单的决策树模型。通过特征选择、数据集划分和递归构建树的过程,我们能够构建一个能够对新样本进行预测的决策树模型。虽然本文的实现较为简单,但它为理解决策树算法提供了一个良好的起点。在实际应用中,可以进一步优化和扩展该模型,例如引入剪枝策略以提高模型的泛化能力。

希望本文对你有所帮助!如果你对决策树算法有更多问题,欢迎在评论区留言讨论。


相关文章:

  • 3月30号
  • Windows10 下QT社区版的安装记录
  • 在 Vue 项目中快速集成 Vant 组件库
  • 磁盘冗余阵列
  • KMeans算法案例
  • 微服务架构中的精妙设计:服务注册/服务发现-Eureka
  • MySQL执行计划分析
  • MATLAB中rmfield函数用法
  • 中国网络安全产业分析报告
  • ngx_get_options
  • 鸿蒙HarmonyOS NEXT设备升级应用数据迁移流程
  • 算法刷题-最近公共祖先-LCA
  • 元编程思想
  • MySQL8.4 NDB Cluster 集群配置安装
  • 《K230 从熟悉到...》圆形检测
  • 推荐系统(二十):TensorFlow 中的两种范式 tf.keras.Model 和 tf.estimator.Estimator
  • playwright解决重复登录问题,通过pytest夹具自动读取storage_state用户状态信息
  • 【深度学习】不管理论,入门从手写数字识别开始
  • Vue3 其它API Teleport 传送门
  • 【多线程】进阶
  • 圣耀做单网站/贵阳百度快照优化排名
  • 设计官网费用/扬州seo推广
  • wordpress直接复制文章/关键词优化排名平台
  • 专做女鞋的网站代发广州/网站seo优化多少钱
  • 品牌故事/网络优化工作内容
  • 新郑做网站/佛山网站建设十年乐云seo