当前位置: 首页 > news >正文

MATLAB基于ResNet18的交通标志识别系统

1. 数据准备

  • 数据集:该数据集包含了大量标注好的交通标志图片,每类标志都有不同的样本。
  • 数据预处理:图像需要进行一些基本的预处理,如调整大小、归一化等,以适应ResNet18的输入要求。

2. 网络设计

  • 使用MATLAB自带的深度学习工具箱,可以直接加载ResNet18模型。ResNet18是一个包含18层的卷积神经网络,适用于图像分类任务。
  • 可以加载预训练的ResNet18模型,并根据交通标志数据集进行微调(fine-tuning)。微调过程中,将预训练的ResNet18模型的前几层保持不变,只修改最后的全连接层,以适应交通标志分类。

3. 训练过程

  • 划分数据集:将数据集分为训练集、验证集和测试集,通常按照80%:10%:10%的比例进行划分。
  • 定义训练选项:设置学习率、批量大小、训练轮次等参数。MATLAB的trainingOptions函数可以用于设置这些超参数。
  • 训练模型:使用trainNetwork函数对模型进行训练,调整学习率等超参数,确保模型能够收敛。
  • 
    
    clc
    clear 
    close all
    % 读取数据
    load('images.mat')
    rng(1)
    % 选取部分数据可视化 前20个
    figure
    for i=1:1:20
        subplot(4,5,i);
        imshow(images(:,:,:,(i-1)*64+7))
    end
    
    
    [M,N] = size(images(:,:,1));%图像大小
    Y = categorical(labels');                       % 标签的数据类型为categorical
    X = images;
    idx = randperm(size(images,4));   % 产生一个和数据个数一致的随机数序列
    num_train = round(0.8*length(X)); % 训练集个数,0.8表示全部数据中随机选取50%作为训练集
    
     
    % 训练集和测试集数据
    X_train = X(:,:,:,idx(1:num_train));
    X_test = X(:,:,:,idx(num_train+1:end));  %这里假设,全部数据中除了
     
    % 训练集和测试集标签
    Y_train = Y(idx(1:num_train),:);
    Y_test = Y(idx(num_train+1:end),:);
    unique(labels)
    
    %% 定义网络层
    %训练网络
    layers = resnet18Layers();
    figure
    plot(layers)
    % options = trainingOptions("sgdm", ...
    %     InitialLearnRate=0.001, ...
    %     LearnRateSchedule="piecewise", ...
    %     L2Regularization=1.0000e-04, ...
    %     MaxEpochs=20, ...
    %     MiniBatchSize=16, ...
    %     ValidationFrequency=20, ...
    %     Plots="training-progress", ...
    %     Metrics="accuracy");
    options = trainingOptions('sgdm', ...      % Adam 梯度下降算法
        'MaxEpochs',20, ...                  % 最大迭代次数 500
        'MiniBatchSize',50, ...              % 批量大小 512
        'InitialLearnRate', 5e-4, ...          % 初始学习率为 0.0005
        'LearnRateSchedule', 'piecewise', ...  % 学习率下降
        'LearnRateDropFactor', 0.1, ...        % 学习率下降因子 0.1
        'LearnRateDropPeriod', 400, ...        % 经过 400 次训练后 学习率为 0.001 * 0.1
        'L2Regularization', 0.0001, ...
        'Shuffle', 'every-epoch', ...          % 打乱数据集
        'Plots', 'training-progress', ...      % 画出曲线
        'Verbose', false);
    net_cnn = trainNetwork(X_train,Y_train,layers,options);
    
    
    % 测试
    testLabel = classify(net_cnn,X_test);
    precision = sum(testLabel==Y_test)/numel(testLabel);
    disp(['测试集分类准确率为',num2str(precision*100),'%'])
    
    save resnet18_checkpoints.mat net_cnn
    
    
    %% 
    %% 混淆矩阵
    
    fig = figure;
    cm = confusionchart(Y_test,testLabel,'RowSummary','row-normalized','ColumnSummary','column-normalized');
    
    fig_Position = fig.Position;
    fig_Position(3) = fig_Position(3)*1.5;
    fig.Position = fig_Position;
    
    
    
    
    
    

4. 模型评估

  • 训练完成后,使用验证集对模型进行评估,查看分类准确率、混淆矩阵等指标。
  • 对测试集进行测试,确保模型的泛化能力。

5. 交通标志识别

  • 使用训练好的模型对新的交通标志图像进行分类预测。可以使用classify函数对图像进行预测,得到该图像属于哪个交通标志类别。

6. 代码获取

相关文章:

  • S32K144入门笔记(十二):LPIT的解读
  • MySQL单表查询大全【SELECT】
  • .NET_Prism基本项目创建
  • Java实体类转JSON时如何避免null值变成“null“?
  • TypeORM 和 Mongoose 是两种非常流行的 ORM 工具
  • Kubernetes pod 控制器 之 Deployment
  • pytorch中的基础数据集
  • CSS引入方式、字体与文本
  • Flask中使用WTForms处理表单验证
  • 前端学习记录:解决路由缓存问题
  • 东芝2323AMW复印机安装纸盒单元后如何添加配件选项
  • 【商城实战(38)】Spring Boot:从本地事务到分布式事务,商城数据一致性的守护之旅
  • 嵌入式系统中的Board Support Package (BSP)详解:以Xilinx Zynq为例
  • AndroidStudio+Android8.0下的Launcher3 导入,编译,烧录,调试
  • BSP、设备树和HAL的关系:以Xilinx Zynq为例与PC BIOS的对比
  • nginx请求限流设置:常见的有基于 IP 地址的限流、基于请求速率的限流以及基于连接数的限流
  • 结构体定义与应用
  • 查看分析日志文件、root密码不记得了,那应该怎么解决这些问题
  • Web开发-PHP应用鉴别修复AI算法流量检测PHP.INI通用过滤内置函数
  • SGMEA: Structure-Guided Multimodal Entity Alignment
  • 人妖和美女做视频网站/直播代运营公司
  • 国土局网站建设制度/搜索引擎优化课程
  • 网站个人简介怎么做/站长工具海角
  • 个人网站下载/舆情优化公司
  • 网站建设 网站制作/seo交流群
  • 泉州网站建设培训机构/今日百度搜索风云榜