matlab中随机森林算法的实现
基于MATLAB实现随机森林算法
一、核心实现代码
%% 随机森林算法实现(分类任务)
function random_forest_demo()% 1. 数据加载与预处理data = readtable('data.csv'); % 加载CSV数据X = data(:, 1:end-1); % 特征矩阵Y = data(:, end); % 标签向量% 数据标准化[X_scaled, ps_input] = mapminmax(X', 0, 1);X_scaled = X_scaled';% 2. 数据集划分cv = cvpartition(Y, 'HoldOut', 0.3);X_train = X_scaled(cv.training,:);Y_train = Y(cv.training);X_test = X_scaled(cv.test,:);Y_test = Y(cv.test);% 3. 模型训练(分类任务)numTrees = 200; % 决策树数量minLeaf = 5; % 最小叶子节点数model = TreeBagger(numTrees, X_train, Y_train, ...'Method', 'classification', ...'MinLeafSize', minLeaf, ...'OOBPrediction', 'on'); % 启用袋外预测% 4. 模型评估Y_pred = predict(model, X_test);accuracy = sum(strcmp(Y_pred, Y_test)) / numel(Y_test);fprintf('分类准确率: %.2f%%
', accuracy*100);% 5. 特征重要性分析importance = model.OOBPermutedPredictorDeltaError;figure;barh(importance);set(gca, 'YTickLabel', data.Properties.VariableNames(1:end-1));title('特征重要性排序');% 6. 预测新数据new_sample = [0.2,0.5,0.7,0.1]; % 示例新数据new_sample_scaled = mapminmax('apply', new_sample', ps_input);prediction = predict(model, new_sample_scaled);disp(['新数据预测类别: ', char(prediction)]);
end
二、关键算法解析
1. 核心函数 TreeBagger
-
参数说明:
TreeBagger(numTrees, X, Y, 'Method', 'classification', ...)
numTrees
:决策树数量(推荐50-500)Method
:任务类型(classification
/regression
)MinLeafSize
:最小叶子节点数(防止过拟合)OOBPrediction
:启用袋外误差计算
-
输出对象属性:
model.OOBPermutedPredictorDeltaError % 特征重要性评分 model.PredictorNames % 特征名称 model.ClassNames % 分类类别标签
2. 参数优化策略
% 网格搜索调参示例(分类任务)
bestAcc = 0;
for nt = [50,100,200]for ml = [3,5,7]model = TreeBagger(nt, X_train, Y_train, ...'MinLeafSize', ml, 'OOBPrediction', 'on');Y_pred = predict(model, X_test);acc = sum(strcmp(Y_pred, Y_test))/numel(Y_test);if acc > bestAccbestAcc = acc;bestParams = struct('NumTrees', nt, 'MinLeaf', ml);endend
end
disp(['最佳参数: ', num2str(bestParams.NumTrees), '树, 最小叶子: ', num2str(bestParams.MinLeaf)]);
参考仿真代码 matlab中随机森林算法的实现 www.youwenfan.com/contentcsd/65107.html
三、性能评估指标
1. 分类任务评估
% 混淆矩阵
confMat = confusionmat(Y_test, Y_pred);
disp('分类报告:');
disp(classificationReport(confMat));% ROC曲线
[~,~,~,AUC] = perfcurve(Y_test, str2double(Y_pred), 'positiveClass');
figure;
plot([0,1],[0,1],'r--');
hold on;
plot(rocData(:,1), rocData(:,2));
xlabel('假阳性率'); ylabel('真阳性率');
title(['AUC = ', num2str(AUC)]);
2. 回归任务评估
% 回归任务代码修改
model = TreeBagger(numTrees, X_train, Y_train, 'Method', 'regression');% 评估指标
rmse = sqrt(mean((Y_test - predict(model, X_test)).^2));
r2 = 1 - sum((Y_test - predict(model, X_test)).^2)/var(Y_test);
fprintf('回归性能: RMSE=%.2f, R²=%.2f
', rmse, r2);
四、高级功能实现
1. 特征重要性可视化
% 基尼重要性排序
imp = model.OOBPermutedPredictorDeltaError;
[~, idx] = sort(imp, 'descend');
figure;
barh(idx, imp(idx));
set(gca, 'YTickLabel', data.Properties.VariableNames(1:end-1));
title('特征重要性(基尼指数)');
2. 多分类任务扩展
% 加载多分类数据集(如鸢尾花)
load fisheriris
X = meas;
Y = species;% 训练模型
model = TreeBagger(100, X, Y, 'Method', 'classification');% 预测新样本
new_sample = [5.1,3.5,1.4,0.2];
predicted_class = predict(model, new_sample);
disp(['预测类别: ', char(predicted_class)]);
五、工程优化技巧
- 并行计算加速:
% 启用并行计算
if isempty(gcp('nocreate'))parpool; % 启动并行池
end
model = TreeBagger(200, X, Y, 'UseParallel', true);
- 缺失值处理:
% 自动处理缺失值
model = TreeBagger(100, X, Y, 'MissingData', 'on');
- 增量学习:
% 分阶段训练
model = TreeBagger(50, X_train1, Y_train1);
model = growTrees(model, X_train2, Y_train2); % 新增训练数据
六、应用场景示例
1. 医疗诊断(分类)
% 加载糖尿病数据集
load diabetes;
X = diabetes(:,1:8);
Y = diabetes(:,9);% 训练与评估
model = TreeBagger(150, X, Y, 'Method', 'classification');
cv = crossval(model, 'KFold', 5);
cvLoss = kfoldLoss(cv);
disp(['5折交叉验证损失: ', num2str(cvLoss)]);
2. 金融风控(回归)
% 加载贷款违约数据
data = readtable('loan_data.csv');
X = data(:,2:end-1);
Y = data.Default;% 回归模型训练
model = TreeBagger(200, X, Y, 'Method', 'regression');
Y_pred = predict(model, X);
rmse = sqrt(mean((Y - Y_pred).^2));
七、常见问题解决
- 过拟合控制:
- 增加
MinLeafSize
参数 - 使用
OOBPrediction
进行袋外验证 - 限制树的最大深度:
MaxNumSplits
- 增加
- 类别不平衡处理:
% 加权随机森林
model = TreeBagger(100, X, Y, ...'Method', 'classification', ...'ClassNames', {'class1','class2'}, ...'Prior', 'uniform'); % 或 'empirical'
- 大规模数据处理:
- 使用
RF_MexStandalone
预编译包加速 - 分块训练:
growTrees
函数增量更新模型
- 使用
八、完整工程模板
%% 随机森林工程模板
function main_rf()% 数据输入[X, Y] = load_dataset('dataset.mat');% 数据预处理[X_scaled, ps] = preprocess_data(X);% 模型训练model = train_rf(X_scaled, Y);% 模型评估evaluate_model(model, X_scaled, Y);% 模型保存save_model(model, 'random_forest_model.mat');
end
该实现方案整合了MATLAB内置的TreeBagger
函数与工程优化技巧,支持从数据预处理到模型部署的全流程。通过调整树的数量、特征子集大小等参数,可适应不同数据集的需求。对于超大规模数据,建议结合Hadoop/MAPReduce进行分布式计算。