A Hand-Rolled ID3 Decision Tree in C++
Background
Picking up where the last post left off: after finishing KNN, Xiao H kept working through machine learning topics, and this time the subject is decision trees. Building a tree to carry out a classification task really is a vivid way to picture it.
Concepts
A decision tree is a supervised learning algorithm commonly used for classification. The ID3 algorithm selects the best feature to split on by computing information gain, ultimately producing a tree in which each internal node tests a feature/attribute and each leaf node holds a class label. Information gain is a concept from information theory that measures how much uncertainty is removed by splitting the dataset on a given feature.
The core of ID3 is information gain: for each feature, compute its "contribution" to classifying the dataset, choose the feature with the largest contribution as the split at the current node, and repeat until all samples are correctly classified or no further split is possible.
Key Concepts
Entropy: entropy measures how "mixed" a dataset is; the higher the entropy, the more mixed the data and the less clear-cut the classification.
Conditional Entropy: when feature A is used to partition dataset D, the weighted average entropy of the resulting subsets is called the conditional entropy.
Information Gain: the difference between the entropy of the original dataset and the conditional entropy after partitioning on feature A; it measures feature A's contribution to classification. The larger the gain, the more the split on A reduces the data's disorder, and the more suitable A is as the split at the current node.
(These formulas are easy to look up; the standard definitions are reproduced below for reference.)
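With p_k the fraction of samples in D belonging to class k, and D_v the subset of D on which feature A takes value v:

$$H(D) = -\sum_{k} p_k \log_2 p_k$$

$$H(D \mid A) = \sum_{v} \frac{|D_v|}{|D|} H(D_v)$$

$$\mathrm{Gain}(D, A) = H(D) - H(D \mid A)$$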
ID3 Algorithm Flow
- Initialize: use all training samples as the root node's dataset.
- Check the termination conditions:
  - If every sample in the current dataset belongs to the same class, mark the node as a leaf and return that class.
  - If no features remain to split on, mark the node as a leaf and return the most frequent class among its samples (majority vote).
- Choose the best feature:
  - Compute the entropy H(D) of the current dataset.
  - For each unused feature A, compute its information gain Gain(D, A).
  - Pick the feature A with the largest information gain as the split feature for the current node.
- Split the dataset:
  - Partition the dataset into subsets by the values of feature A (one subset per value).
  - Create a child node for each subset and recursively repeat steps 2-4 until a termination condition is met.
Data Preparation
Suppose we have a dataset whose columns are Age, Income, Marital Status, and Label, with some number of rows.
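For concreteness, the six rows used by the test program at the end of this post are:

| Age | Income | Marital Status | Label   |
|-----|--------|----------------|---------|
| 30  | High   | Single         | Class A |
| 35  | Low    | Married        | Class B |
| 40  | Medium | Divorced       | Class A |
| 25  | Low    | Single         | Class C |
| 50  | High   | Married        | Class B |
| 45  | Low    | Divorced       | Class A |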
Data Structures
We encapsulate the whole decision tree in a DecisionTree class.
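The tree operates on a plain Example record and a Dataset alias, taken from the full listing at the end:

struct Example {
    int age;                   // numeric; discretized into buckets at split time
    std::string income;
    std::string marital_status;
    std::string label;         // the class we want to predict
};

using Dataset = std::vector<Example>;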
The DecisionTree class
class DecisionTree {
public:
    // Node type enum
    enum class NodeType { INTERNAL, LEAF };

    // Node structure
    struct Node {
        NodeType type;
        std::string feature;                                   // splitting feature (internal nodes)
        std::map<std::string, std::unique_ptr<Node>> children; // feature value -> subtree
        std::optional<std::string> label;                      // class label (leaf nodes)
    };

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    std::string predict_helper(const Node* node, const Example& example) const;

    std::unique_ptr<Node> root_;
};
Method Implementations
Computing Entropy
// Compute the entropy of a dataset
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;
    // Count occurrences of each label
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;
    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) { // avoid log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}
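As a quick sanity check: the six training samples used in main below contain 3 × "Class A", 2 × "Class B", and 1 × "Class C", so

H(D) = -(1/2)·log2(1/2) - (1/3)·log2(1/3) - (1/6)·log2(1/6) ≈ 0.500 + 0.528 + 0.431 ≈ 1.459 bits,

which is what entropy() returns for the whole dataset.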
Computing Information Gain
// Compute the information gain of splitting on a feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    // Partition the dataset by the feature's values
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }
    // Weighted average of subset entropies (the conditional entropy)
    double weighted_entropy = 0.0;
    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }
    return initial_entropy - weighted_entropy;
}
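information_gain (and later build_tree) relies on the get_feature_value helper, which maps a feature name to a string value and discretizes the numeric age so ID3 can treat it as categorical. It is part of the full listing below, but is worth showing here:

// Map a feature name to the example's value; numeric age is discretized
// into three buckets so ID3 can treat it as a categorical feature
std::string DecisionTree::get_feature_value(const Example& example, const std::string& feature) const {
    if (feature == "income") {
        return example.income;
    } else if (feature == "marital_status") {
        return example.marital_status;
    } else if (feature == "age") {
        if (example.age < 30) {
            return "young";
        } else if (example.age >= 30 && example.age < 50) {
            return "middle";
        } else {
            return "old";
        }
    } else {
        throw std::runtime_error("Unknown feature: " + feature);
    }
}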
Choosing the Best Feature
// Choose the feature with the highest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;
    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}
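On the root call with the six training rows from main, for example, the printed gains work out to roughly income ≈ 0.333, marital_status ≈ 1.126, and age ≈ 0.918, so marital_status is chosen as the root split.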
Building the Tree
// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }
    // Count the labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    // If all samples belong to one class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }
    // If no features remain, fall back to the most common label (majority vote)
    if (features.empty()) {
        auto most_common_label = std::max_element(
            label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }
    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;
    // Debug output: show the chosen feature
    std::cout << "Selected best feature: " << best_feature
              << " for " << data.size() << " samples" << std::endl;
    // Split the data on the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }
    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }
    // Recursively build the subtrees
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }
    return node;
}
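Training itself is then a one-liner: fit simply builds the tree from the root (also part of the full listing):

// Train the tree on a dataset and a list of candidate features
void DecisionTree::fit(const Dataset& data, const std::vector<std::string>& features) {
    root_ = build_tree(data, features);
}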
Prediction
// Prediction entry point
std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

// Recursive prediction helper
std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value +
                                 "' not found in tree for feature '" + feature + "'");
    }
    return predict_helper(it->second.get(), example);
}
Complete Code
The core logic is all shown above; here is a complete listing, including the helper functions and a simple test in main:
#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <optional>
#include <unordered_map>
#include <algorithm>
#include <cmath>
#include <memory>
#include <cassert>
#include <stdexcept> // std::runtime_error

struct Example {
    int age;
    std::string income;
    std::string marital_status;
    std::string label;
};

using Dataset = std::vector<Example>;

class DecisionTree {
public:
    // Node type enum
    enum class NodeType { INTERNAL, LEAF };

    // Node structure
    struct Node {
        NodeType type;
        std::string feature;                                   // splitting feature (internal nodes)
        std::map<std::string, std::unique_ptr<Node>> children; // feature value -> subtree
        std::optional<std::string> label;                      // class label (leaf nodes)
    };

    void fit(const Dataset& data, const std::vector<std::string>& features);
    std::string predict(const Example& example) const;

private:
    std::unique_ptr<Node> build_tree(const Dataset& data, const std::vector<std::string>& features);
    double entropy(const Dataset& data) const;
    double information_gain(const Dataset& data, const std::string& feature) const;
    std::string choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const;
    std::string get_feature_value(const Example& example, const std::string& feature) const;
    std::string predict_helper(const Node* node, const Example& example) const;

    std::unique_ptr<Node> root_;
};

// Compute the entropy of a dataset
double DecisionTree::entropy(const Dataset& data) const {
    if (data.empty()) return 0.0;
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    double total = static_cast<double>(data.size());
    double entropy_value = 0.0;
    for (const auto& [label, count] : label_counts) {
        double p = static_cast<double>(count) / total;
        if (p > 0) { // avoid log(0)
            entropy_value -= p * std::log2(p);
        }
    }
    return entropy_value;
}

// Map a feature name to the example's value; numeric age is discretized
std::string DecisionTree::get_feature_value(const Example& example, const std::string& feature) const {
    if (feature == "income") {
        return example.income;
    } else if (feature == "marital_status") {
        return example.marital_status;
    } else if (feature == "age") {
        if (example.age < 30) {
            return "young";
        } else if (example.age >= 30 && example.age < 50) {
            return "middle";
        } else {
            return "old";
        }
    } else {
        throw std::runtime_error("Unknown feature: " + feature);
    }
}

// Compute the information gain of splitting on a feature
double DecisionTree::information_gain(const Dataset& data, const std::string& feature) const {
    double initial_entropy = entropy(data);
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, feature);
        split_data[feature_value].push_back(ex);
    }
    double weighted_entropy = 0.0;
    for (const auto& [value, subset] : split_data) {
        double weight = static_cast<double>(subset.size()) / data.size();
        weighted_entropy += weight * entropy(subset);
    }
    return initial_entropy - weighted_entropy;
}

// Choose the feature with the highest information gain
std::string DecisionTree::choose_best_feature(const Dataset& data, const std::vector<std::string>& features) const {
    double max_gain = -1.0;
    std::string best_feature;
    std::cout << "  Information gains:" << std::endl;
    for (const auto& feature : features) {
        double gain = information_gain(data, feature);
        std::cout << "    " << feature << ": " << gain << std::endl;
        if (gain > max_gain) {
            max_gain = gain;
            best_feature = feature;
        }
    }
    std::cout << "  Best feature: " << best_feature << " (gain: " << max_gain << ")" << std::endl;
    return best_feature;
}

// Recursively build the decision tree
std::unique_ptr<DecisionTree::Node> DecisionTree::build_tree(const Dataset& data, const std::vector<std::string>& features) {
    if (data.empty()) {
        return nullptr;
    }
    // Count the labels
    std::unordered_map<std::string, int> label_counts;
    for (const auto& ex : data) {
        label_counts[ex.label]++;
    }
    // If all samples belong to one class, create a leaf node
    if (label_counts.size() == 1) {
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = label_counts.begin()->first;
        return node;
    }
    // If no features remain, fall back to the most common label (majority vote)
    if (features.empty()) {
        auto most_common_label = std::max_element(
            label_counts.begin(), label_counts.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
        auto node = std::make_unique<Node>();
        node->type = NodeType::LEAF;
        node->label = most_common_label->first;
        return node;
    }
    // Choose the best feature
    std::string best_feature = choose_best_feature(data, features);
    auto node = std::make_unique<Node>();
    node->type = NodeType::INTERNAL;
    node->feature = best_feature;
    // Debug output: show the chosen feature
    std::cout << "Selected best feature: " << best_feature
              << " for " << data.size() << " samples" << std::endl;
    // Split the data on the best feature
    std::map<std::string, Dataset> split_data;
    for (const auto& ex : data) {
        std::string feature_value = get_feature_value(ex, best_feature);
        split_data[feature_value].push_back(ex);
    }
    // Build the list of remaining features
    std::vector<std::string> remaining_features;
    for (const auto& feature : features) {
        if (feature != best_feature) {
            remaining_features.push_back(feature);
        }
    }
    // Recursively build the subtrees
    for (const auto& [value, subset] : split_data) {
        node->children[value] = build_tree(subset, remaining_features);
    }
    return node;
}

// Train the tree on a dataset and a list of candidate features
void DecisionTree::fit(const Dataset& data, const std::vector<std::string>& features) {
    root_ = build_tree(data, features);
}

// Prediction entry point
std::string DecisionTree::predict(const Example& example) const {
    if (!root_) {
        throw std::runtime_error("Decision tree has not been trained");
    }
    return predict_helper(root_.get(), example);
}

// Recursive prediction helper
std::string DecisionTree::predict_helper(const Node* node, const Example& example) const {
    if (node->type == NodeType::LEAF) {
        return node->label.value();
    }
    const std::string& feature = node->feature;
    std::string feature_value = get_feature_value(example, feature);
    auto it = node->children.find(feature_value);
    if (it == node->children.end()) {
        throw std::runtime_error("Feature value '" + feature_value +
                                 "' not found in tree for feature '" + feature + "'");
    }
    return predict_helper(it->second.get(), example);
}

int main() {
    try {
        Dataset data = {
            {30, "High", "Single", "Class A"},
            {35, "Low", "Married", "Class B"},
            {40, "Medium", "Divorced", "Class A"},
            {25, "Low", "Single", "Class C"},
            {50, "High", "Married", "Class B"},
            {45, "Low", "Divorced", "Class A"}
        };
        std::vector<std::string> features = {"income", "marital_status", "age"};

        // Train the decision tree
        std::cout << "Training decision tree..." << std::endl;
        DecisionTree dt;
        dt.fit(data, features);
        std::cout << "Training completed!" << std::endl;

        // Predict a new example
        Example new_example = {30, "High", "Single", ""};
        std::cout << "Predicting for example: Age=" << new_example.age
                  << ", Income=" << new_example.income
                  << ", Marital Status=" << new_example.marital_status << std::endl;
        std::string prediction = dt.predict(new_example);
        std::cout << "Prediction: " << prediction << std::endl;

        // Test a few more examples
        std::vector<Example> test_examples = {
            {25, "Low", "Single", ""},
            {45, "High", "Married", ""},
            {35, "Medium", "Divorced", ""}
        };
        std::cout << "\nTesting additional examples:" << std::endl;
        for (const auto& example : test_examples) {
            std::string pred = dt.predict(example);
            std::cout << "Age=" << example.age << ", Income=" << example.income
                      << ", Marital=" << example.marital_status
                      << " -> " << pred << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
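The listing uses C++17 features (std::optional, structured bindings), so build it with a C++17 compiler. Assuming you save it as decision_tree.cpp (the filename is arbitrary):

g++ -std=c++17 -o decision_tree decision_tree.cpp
./decision_tree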
Closing
That wraps up this chapter; Xiao H is off to start a happy weekend. If you run into any problems with the code, feel free to contact the author.