当前位置: 首页 > news >正文

MATLAB实战:机器学习分类回归示例

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集
% 鸢尾花数据集(分类)
load fisheriris;
X_iris = meas;      % 150x4 特征矩阵
Y_iris = species;   % 150x1 类别标签

% 手写数字数据集(分类)
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
    'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[trainImgs, testImgs] = splitEachLabel(imds, 0.7, 'randomized');

% 提取HOG特征
numTrain = numel(trainImgs.Files);
hogFeatures = zeros(numTrain, 324);  % HOG特征维度
for i = 1:numTrain
    img = readimage(trainImgs, i);
    hogFeatures(i, :) = extractHOGFeatures(img);
end
trainLabels = trainImgs.Labels;

% 汽车数据集(回归)
load carsmall;
X_car = [Weight, Horsepower, Cylinders];  % 100x3 特征矩阵
Y_car = MPG;                              % 100x1 响应变量

%% 鸢尾花分类任务
rng(1); % 设置随机种子保证可重复性
cv = cvpartition(Y_iris, 'HoldOut', 0.3);
idxTrain = training(cv);
idxTest = test(cv);

% 训练KNN模型
knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5);
knnPred = predict(knnModel, X_iris(idxTest,:));
knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest)

% 训练决策树
treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain));
treePred = predict(treeModel, X_iris(idxTest,:));
treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest)

% 训练SVM
svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain));
svmPred = predict(svmModel, X_iris(idxTest,:));
svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest)

% 混淆矩阵可视化
figure;
confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix');

%% 手写数字分类(使用KNN示例)
% 训练KNN模型
knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3);

% 处理测试集
numTest = numel(testImgs.Files);
testFeatures = zeros(numTest, 324);
testLabels = testImgs.Labels;
for i = 1:numTest
    img = readimage(testImgs, i);
    testFeatures(i, :) = extractHOGFeatures(img);
end

% 预测并评估
digitPred = predict(knnDigitModel, testFeatures);
digitAcc = sum(digitPred == testLabels) / numel(testLabels)

%% 回归任务(汽车数据)
rng(2);
cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25);
idxTrain_car = training(cv_car);
idxTest_car = test(cv_car);

% 线性回归
lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car));
lmPred = predict(lmModel, X_car(idxTest_car,:));
lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 多项式回归(二次项)
polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2');
polyPred = predict(polyModel, X_car(idxTest_car,:));
polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 可视化回归结果
figure;
scatter(Y_car(idxTest_car), lmPred, 'b');
hold on;
scatter(Y_car(idxTest_car), polyPred, 'r');
plot([0,50], [0,50], 'k--');
xlabel('Actual MPG');
ylabel('Predicted MPG');
legend('Linear', 'Polynomial', 'Ideal');
title('Regression Results Comparison');

关键函数说明:

  1. 分类模型训练:

    • fitcknn(): K近邻分类器

    • fitctree(): 决策树分类器

    • fitcecoc(): 多类SVM分类器

  2. 回归模型训练:

    • fitlm(): 线性/多项式回归

    • 'poly2'参数: 指定二次多项式项

  3. 评估指标:

    • confusionchart(): 可视化混淆矩阵

    • loss(): 计算均方误差(回归)

    • 准确率 = 正确预测数/总样本数(分类)

执行结果

鸢尾花分类准确率:
knnAcc = 0.9778
treeAcc = 0.9556
svmAcc = 0.9778

手写数字分类准确率:
digitAcc = 0.9432

回归均方误差:
lmMSE = 15.672
polyMSE = 12.845

注意事项:

  1. 特征工程

    • 手写数字使用HOG特征替代原始像素

    • 汽车数据组合多个特征(重量/马力/气缸数)

  2. 数据预处理

    • 自动处理缺失值(fitlm会排除含NaN的行)

    • 分类数据自动编码(SVM使用整数编码)

  3. 模型优化

    • 可通过crossval函数进行交叉验证

    • 使用HyperparameterOptimization参数自动调优

  4. 可视化

    • 回归结果对比图显示预测值与实际值关系

    • 混淆矩阵直观展示分类错误分布

此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。

相关文章:

  • MATLAB实战:实现数字调制解调仿真
  • gcc相关内容
  • Java中的线程池实现
  • 【图像处理入门】2. Python中OpenCV与Matplotlib的图像操作指南
  • 37. Sudoku Solver
  • uniapp与微信小程序开发平台联调无法打开IDE
  • [USACO1.5] 八皇后 Checker Challenge Java
  • 业界宽松内存模型的不统一而导致的软件问题, gcc, linux kernel, JVM
  • 【KWDB 创作者计划】_再热垃圾发电汽轮机仿真与监控系统:KaiwuDB 批量插入10万条数据性能优化实践
  • 2.4 TypeScript 中的展开运算符
  • 打造苹果级视差滚动动画:现代网页滚动动画技术详解
  • STM32入门教程——LED闪烁LED流水灯蜂鸣器
  • 【清晰教程】查看和修改Git配置情况
  • Java中Redis面试题集锦(含过期策略详解)
  • 科普:Linux `su` 切换用户后出现 `$` 提示符,如何排查和解决?
  • 论文笔记: Urban Region Embedding via Multi-View Contrastive Prediction
  • leetcode付费题 353. 贪吃蛇游戏解题思路
  • 2025年- H61-Lc169--74.搜索二维矩阵(二分查找)--Java版
  • 【技能拾遗】——家庭宽带单线复用布线与配置(移动2025版)
  • 计算机视觉与深度学习 | 基于Matlab的门禁指纹识别与人脸识别双系统实现
  • wordpress 菜单 导航/百度快照优化公司
  • 营口做网站/培训心得体会800字
  • 集团网站开发多少钱/重庆官网seo分析
  • 婚恋网站制作要多少钱/百度云搜索引擎入口盘多多
  • 漳州网站建设去博大a优/兔子bt搜索
  • 雄安 网站建设/怎么做一个网站