MATLAB中进行深度学习网络训练的模型评估步骤
文章目录
- 前言
- 环境配置
- 一、基础性能评估
- 二、高级评估指标
- 三、模型解释与可视化
- 四、交叉验证与模型选择
- 五、部署前的优化
前言
在 MATLAB 中进行深度学习网络训练后的模型评估是确保模型性能和可靠性的关键环节。以下是详细的评估步骤和方法。
环境配置
MATLAB下载安装教程:https://blog.csdn.net/tyatyatya/article/details/147879353
MATLAB下载地址链接:https://pan.quark.cn/s/364584a880f7
一、基础性能评估
- 分类准确率(Classification Accuracy)
% 在测试集上进行预测
YPred = classify(net, imdsTest); % 对图像数据
YPred = predict(net, XTest); % 对数值数据% 计算准确率
YTest = imdsTest.Labels;
accuracy = mean(YPred == YTest);
fprintf('测试集准确率: %.2f%%\n', accuracy*100);
- 混淆矩阵(Confusion Matrix)
% 生成混淆矩阵
cm = confusionmat(YTest, YPred);% 可视化混淆矩阵
figure
cmplot = confusionchart(cm, categories(YTest));
cmplot.Title = '混淆矩阵';
cmplot.RowSummary = 'row-normalized'; % 显示行归一化(召回率)
cmplot.ColumnSummary = 'column-normalized'; % 显示列归一化(精确率)
- 精确率、召回率与 F1 分数
% 计算每个类别的精确率、召回率和F1分数
categories = unique(YTest);
metrics = table(categories, zeros(length(categories), 3), 'VariableNames', {'Category', 'Precision', 'Recall', 'F1Score'});for i = 1:length(categories)truePositives = cm(i,i);falsePositives = sum(cm(:,i)) - truePositives;falseNegatives = sum(cm(i,:)) - truePositives;metrics.Precision(i) = truePositives / (truePositives + falsePositives);metrics.Recall(i) = truePositives / (truePositives + falseNegatives);metrics.F1Score(i) = 2 * (metrics.Precision(i) * metrics.Recall(i)) / (metrics.Precision(i) + metrics.Recall(i));
end% 计算宏平均和微平均
macroPrecision = mean(metrics.Precision);
macroRecall = mean(metrics.Recall);
macroF1 = mean(metrics.F1Score);microPrecision = sum(diag(cm)) / sum(sum(cm));
microRecall = microPrecision; % 微平均精确率和召回率相等
microF1 = 2 * (microPrecision * microRecall) / (microPrecision + microRecall);fprintf('宏平均 F1 分数: %.4f\n', macroF1);
fprintf('微平均 F1 分数: %.4f\n', microF1);
二、高级评估指标
- ROC 曲线与 AUC 值(二分类问题)
% 获取预测概率
[YPred, scores] = classify(net, imdsTest, 'OutputAs', 'probabilities');% 计算ROC曲线
figure
for i = 1:numel(categories)[x, y, t, auc] = perfcurve(YTest, scores(:,i), categories(i));plot(x, y, 'DisplayName', [categories(i), ': AUC = ', num2str(auc, '%.3f')])
end
title('ROC曲线')
xlabel('假阳性率 (FPR)')
ylabel('真阳性率 (TPR)')
legend
grid on
- 损失函数曲线分析
% 绘制训练过程中的损失函数曲线
figure
plot(tr.TrainingLoss, 'b-', 'LineWidth', 2)
hold on
plot(tr.ValidationLoss, 'r-', 'LineWidth', 2)
title('训练与验证损失')
xlabel('训练轮次 (Epoch)')
ylabel('损失值')
legend('训练损失', '验证损失')
grid on
- 学习率调整分析
% 绘制学习率随训练轮次的变化
figure
plot(tr.LearnRate, 'LineWidth', 2)
title('学习率调整')
xlabel('训练轮次 (Epoch)')
ylabel('学习率')
grid on
三、模型解释与可视化
- 类激活映射(Class Activation Mapping, CAM)
% 计算并可视化类激活映射
I = imread('test_image.jpg');
[YPred, scores] = classify(net, I, 'OutputAs', 'probabilities');
cam = activation(net, I, 'last_conv_layer', 'OutputAs', 'image'); % 替换为实际最后卷积层名称figure
subplot(1,2,1)
imshow(I)
title('原始图像')subplot(1,2,2)
imshow(I)
hold on
h = imagesc(cam, 'AlphaData', cam);
colormap jet
axis off
title(['预测: ', string(YPred), ', 置信度: ', num2str(max(scores), '%.2f')])
colorbar
- 特征可视化
% 可视化中间层特征
I = imread('test_image.jpg');
features = activation(net, I, 'conv2_1'); % 替换为实际层名称% 可视化前16个特征图
figure
for i = 1:min(16, size(features, 3))subplot(4, 4, i)imshow(features(:,:,i), 'DisplayRange', [])title(['特征图 ', num2str(i)])
end
- 决策边界分析(二维数据)
% 生成网格点
[x1Grid, x2Grid] = meshgrid(linspace(min(XTest(:,1)), max(XTest(:,1)), 100), ...linspace(min(XTest(:,2)), max(XTest(:,2)), 100));
gridPoints = [x1Grid(:), x2Grid(:)];% 预测网格点
YPredGrid = classify(net, gridPoints);% 可视化决策边界
figure
gscatter(XTest(:,1), XTest(:,2), YTest)
hold on
contourf(x1Grid, x2Grid, reshape(YPredGrid, size(x1Grid)), 'Alpha', 0.3)
title('决策边界可视化')
legend('类别1', '类别2', '决策边界')
四、交叉验证与模型选择
- K 折交叉验证
% 设置K折交叉验证
k = 5;
cv = cvpartition(height(tbl), 'KFold', k);% 存储每折的准确率
accuracies = zeros(k, 1);% 执行交叉验证
for i = 1:kidxTrain = training(cv, i);idxTest = test(cv, i);% 训练模型net = trainNetwork(imds(idxTrain), layers, options);% 评估模型YPred = classify(net, imds(idxTest));accuracies(i) = mean(YPred == imds.Labels(idxTest));
end% 计算平均准确率和标准差
meanAccuracy = mean(accuracies);
stdAccuracy = std(accuracies);
fprintf('交叉验证准确率: %.2f%% ± %.2f%%\n', meanAccuracy*100, stdAccuracy*100);
- 模型比较
% 比较不同模型架构
models = {'resnet18', 'resnet50', 'alexnet'};
results = table(models, zeros(length(models), 1), 'VariableNames', {'Model', 'Accuracy'});for i = 1:length(models)% 加载预训练模型net = eval(models{i});% 修改网络结构% ... [省略网络修改代码] ...% 训练模型trainedNet = trainNetwork(imdsTrain, lgraph, options);% 评估模型YPred = classify(trainedNet, imdsTest);results.Accuracy(i) = mean(YPred == YTest);
end% 显示比较结果
results = sortrows(results, 'Accuracy', 'descend');
disp(results);
五、部署前的优化
- 模型量化
% 量化模型以减小尺寸和加速推理
quantizedNet = quantizeNetwork(net, 'WeightPrecision', 8, 'ActivationPrecision', 8);% 评估量化模型
YPredQuantized = classify(quantizedNet, imdsTest);
accuracyQuantized = mean(YPredQuantized == YTest);
fprintf('量化模型准确率: %.2f%%\n', accuracyQuantized*100);
- 剪枝(Pruning)
% 对模型进行剪枝
prunedNet = pruneNetwork(net, 'Percentage', 50); % 剪枝50%的连接% 微调剪枝后的模型
optionsFineTune = trainingOptions('sgdm', ...'InitialLearnRate', 0.0001, ...'MaxEpochs', 3);
prunedNet = trainNetwork(imdsTrain, prunedNet, optionsFineTune);% 评估剪枝模型
YPredPruned = classify(prunedNet, imdsTest);
accuracyPruned = mean(YPredPruned == YTest);
fprintf('剪枝模型准确率: %.2f%%\n', accuracyPruned*100);