当前位置: 首页 > news >正文

基于神经网络的手写数字识别系统

基于神经网络的手写数字识别系统

结合模板匹配和神经网络两种方法进行手写数字识别。这个系统包括图像预处理、特征提取、神经网络训练和可视化分析。

%% 基于神经网络的手写数字识别系统%% 清理工作区
clear; clc; close all;%% 加载手写数字数据集
% 使用MATLAB自带的手写数字数据集
digitDatasetPath = fullfile(toolboxdir('nnet'), 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...'IncludeSubfolders', true, ...'LabelSource', 'foldernames');% 显示数据集信息
disp(['数据集包含 ', num2str(numel(imds.Files)), ' 张图像']);
countEachLabel(imds)% 显示随机样本图像
figure('Name', '手写数字样本', 'Position', [100, 100, 800, 400]);
perm = randperm(numel(imds.Files), 20);
for i = 1:20subplot(4, 5, i);img = readimage(imds, perm(i));imshow(img);title(char(imds.Labels(perm(i))));
end
sgtitle('手写数字样本展示');%% 数据预处理
% 将图像调整为28x28像素并转换为灰度图
processedImages = zeros(28, 28, 1, numel(imds.Files));
for i = 1:numel(imds.Files)img = readimage(imds, i);% 转换为灰度图if size(img, 3) == 3img = rgb2gray(img);end% 调整大小img = imresize(img, [28, 28]);% 归一化处理 [0, 1]img = im2double(img);% 图像二值化img = imbinarize(img);% 存储处理后的图像processedImages(:, :, 1, i) = img;
end% 显示预处理后的图像
figure('Name', '预处理后的图像', 'Position', [100, 100, 800, 400]);
for i = 1:20subplot(4, 5, i);imshow(processedImages(:, :, 1, perm(i)));title(char(imds.Labels(perm(i))));
end
sgtitle('预处理后的手写数字');%% 创建模板匹配系统
% 为每个数字创建平均模板
templates = zeros(28, 28, 10);
digitCounts = zeros(1, 10);for i = 1:numel(imds.Files)digit = double(imds.Labels(i));templates(:, :, digit+1) = templates(:, :, digit+1) + processedImages(:, :, 1, i);digitCounts(digit+1) = digitCounts(digit+1) + 1;
end% 计算平均模板
for i = 1:10if digitCounts(i) > 0templates(:, :, i) = templates(:, :, i) / digitCounts(i);end
end% 显示模板
figure('Name', '数字模板', 'Position', [100, 100, 1000, 400]);
for i = 0:9subplot(2, 5, i+1);imshow(templates(:, :, i+1));title(['数字 ', num2str(i), ' 模板']);
end
sgtitle('模板匹配使用的数字模板');%% 模板匹配测试
% 测试模板匹配的准确率
numTest = 200; % 测试样本数量
testIndices = randperm(numel(imds.Files), numTest);
templateResults = zeros(1, numTest);
templateCorrect = 0;figure('Name', '模板匹配结果', 'Position', [100, 100, 1200, 600]);
colormap gray;for i = 1:min(20, numTest) % 只显示前20个结果idx = testIndices(i);testImg = processedImages(:, :, 1, idx);trueLabel = double(imds.Labels(idx));% 计算与每个模板的相似度(使用相关系数)correlations = zeros(1, 10);for digit = 0:9corrMatrix = corrcoef(testImg(:), templates(:, :, digit+1)(:));correlations(digit+1) = corrMatrix(1, 2);end% 选择最相似的数字[~, predLabel] = max(correlations);predLabel = predLabel - 1;% 记录结果templateResults(i) = (predLabel == trueLabel);if predLabel == trueLabeltemplateCorrect = templateCorrect + 1;end% 显示结果subplot(4, 5, i);imshow(testImg);if predLabel == trueLabeltitle(sprintf('True: %d\nPred: %d', trueLabel, predLabel), 'Color', 'g');elsetitle(sprintf('True: %d\nPred: %d', trueLabel, predLabel), 'Color', 'r');end
end% 计算准确率
templateAccuracy = templateCorrect / numTest;
fprintf('模板匹配准确率: %.2f%%\n', templateAccuracy * 100);
sgtitle(sprintf('模板匹配结果 (准确率: %.2f%%)', templateAccuracy*100));%% 准备神经网络数据
% 划分训练集和测试集 (70% 训练, 30% 测试)
[trainIdx, testIdx] = dividerand(numel(imds.Files), 0.7, 0.3);% 创建训练集
XTrain = processedImages(:, :, :, trainIdx);
YTrain = categorical(imds.Labels(trainIdx));% 创建测试集
XTest = processedImages(:, :, :, testIdx);
YTest = categorical(imds.Labels(testIdx));% 显示数据集大小
fprintf('训练集大小: %d\n', numel(trainIdx));
fprintf('测试集大小: %d\n', numel(testIdx));%% 构建神经网络模型
layers = [imageInputLayer([28 28 1], 'Name', 'input')convolution2dLayer(5, 32, 'Padding', 'same', 'Name', 'conv1')batchNormalizationLayer('Name', 'bn1')reluLayer('Name', 'relu1')maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')batchNormalizationLayer('Name', 'bn2')reluLayer('Name', 'relu2')maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')fullyConnectedLayer(128, 'Name', 'fc1')reluLayer('Name', 'relu3')dropoutLayer(0.4, 'Name', 'dropout')fullyConnectedLayer(10, 'Name', 'fc2')softmaxLayer('Name', 'softmax')classificationLayer('Name', 'output')
];% 可视化网络结构
figure('Name', '神经网络结构');
plot(layerGraph(layers));
title('卷积神经网络结构');%% 设置训练选项
options = trainingOptions('adam', ...'InitialLearnRate', 0.001, ...'MaxEpochs', 15, ...'MiniBatchSize', 128, ...'Shuffle', 'every-epoch', ...'ValidationData', {XTest, YTest}, ...'ValidationFrequency', 30, ...'Verbose', true, ...'Plots', 'training-progress', ...'ExecutionEnvironment', 'auto');%% 训练神经网络
disp('开始训练神经网络...');
net = trainNetwork(XTrain, YTrain, layers, options);
disp('神经网络训练完成!');%% 评估神经网络性能
% 在整个测试集上进行预测
YPred = classify(net, XTest);% 计算准确率
accuracy = sum(YPred == YTest) / numel(YTest);
fprintf('神经网络测试准确率: %.2f%%\n', accuracy * 100);% 混淆矩阵
figure('Name', '混淆矩阵', 'Position', [100, 100, 800, 700]);
cm = confusionmat(YTest, YPred);
confusionchart(cm, categories(YTest));
title(sprintf('混淆矩阵 (准确率: %.2f%%)', accuracy*100));%% 可视化神经网络预测结果
% 选择一些样本进行可视化
numSamplesToShow = 20;
testSampleIndices = randperm(numel(testIdx), numSamplesToShow);figure('Name', '神经网络预测结果', 'Position', [100, 100, 1200, 600]);
colormap gray;for i = 1:numSamplesToShowidx = testIdx(testSampleIndices(i));img = processedImages(:, :, 1, idx);trueLabel = char(imds.Labels(idx));predLabel = char(YPred(testSampleIndices(i)));subplot(4, 5, i);imshow(img);if strcmp(trueLabel, predLabel)title(sprintf('True: %s\nPred: %s', trueLabel, predLabel), 'Color', 'g');elsetitle(sprintf('True: %s\nPred: %s', trueLabel, predLabel), 'Color', 'r');end
end
sgtitle(sprintf('神经网络预测结果 (准确率: %.2f%%)', accuracy*100));%% 特征可视化
% 提取卷积层的激活
conv1Activations = activations(net, XTest, 'conv1');
conv2Activations = activations(net, XTest, 'conv2');% 显示卷积层特征图
sampleIdx = testSampleIndices(1); % 使用第一个测试样本
sampleImg = XTest(:, :, :, sampleIdx);figure('Name', '卷积层特征可视化', 'Position', [100, 100, 1200, 800]);% 原始图像
subplot(3, 1, 1);
imshow(sampleImg);
title('原始图像');% 第一卷积层的特征图
subplot(3, 1, 2);
montage(reshape(conv1Activations(:, :, :, sampleIdx), [28, 28]));
title('第一卷积层特征图');% 第二卷积层的特征图
subplot(3, 1, 3);
montage(reshape(conv2Activations(:, :, :, sampleIdx), [14, 14]));
title('第二卷积层特征图');%% 手写数字识别演示
% 创建一个简单的绘图界面,让用户手写数字
f = figure('Name', '手写数字识别演示', 'Position', [200, 200, 600, 500]);
ax = axes('Parent', f, 'Position', [0.1, 0.2, 0.8, 0.7]);
title('在下方区域手写一个数字');% 创建绘图区域
drawingArea = uicontrol('Style', 'text', 'Position', [60, 100, 280, 280], ...'BackgroundColor', 'white');
axes('Position', [0.1, 0.2, 0.8, 0.7]);% 初始化绘图数据
drawing = false;
lastPoint = [0, 0];
imgData = ones(280, 280) * 255; % 白色背景% 鼠标回调函数
set(gcf, 'WindowButtonDownFcn', @startDrawing);
set(gcf, 'WindowButtonUpFcn', @stopDrawing);
set(gcf, 'WindowButtonMotionFcn', @draw);% 创建按钮
uicontrol('Style', 'pushbutton', 'String', '识别', ...'Position', [100, 50, 100, 30], ...'Callback', @recognizeDigit);uicontrol('Style', 'pushbutton', 'String', '清除', ...'Position', [220, 50, 100, 30], ...'Callback', @clearDrawing);% 结果显示区域
resultText = uicontrol('Style', 'text', 'String', '结果将显示在这里', ...'Position', [100, 20, 200, 20], ...'FontSize', 12, 'FontWeight', 'bold');%% 绘图回调函数
function startDrawing(~, ~)drawing = true;
endfunction stopDrawing(~, ~)drawing = false;lastPoint = [0, 0];
endfunction draw(~, ~)if drawingcurrentPoint = get(gca, 'CurrentPoint');x = round(currentPoint(1, 1));y = round(currentPoint(1, 2));% 确保坐标在绘图区域内if x >= 1 && x <= 280 && y >= 1 && y <= 280if lastPoint(1) > 0 && lastPoint(2) > 0% 在两点之间画线lineX = linspace(lastPoint(1), x, 50);lineY = linspace(lastPoint(2), y, 50);for k = 1:50px = round(lineX(k));py = round(lineY(k));if px >= 1 && px <= 280 && py >= 1 && py <= 280% 绘制粗线for i = -2:2for j = -2:2if px+i > 0 && px+i <= 280 && py+j > 0 && py+j <= 280imgData(py+j, px+i) = 0; % 黑色endendendendendend% 更新图像imshow(imgData, 'Parent', gca);lastPoint = [x, y];endend
endfunction recognizeDigit(~, ~)% 预处理用户绘制的图像userImg = imresize(imgData, [28, 28]);userImg = imcomplement(userImg); % 反转为黑底白字userImg = im2double(userImg);% 使用神经网络进行预测[predLabel, scores] = classify(net, userImg);% 显示结果set(resultText, 'String', sprintf('识别结果: %s (置信度: %.2f%%)', char(predLabel), max(scores)*100));% 显示处理后的图像figure('Name', '预处理后的手写数字');subplot(1, 2, 1);imshow(imcomplement(imgData)); % 原始手写图像title('用户手写数字');subplot(1, 2, 2);imshow(userImg);title('预处理后的图像');
endfunction clearDrawing(~, ~)imgData = ones(280, 280) * 255; % 重置为白色背景imshow(imgData, 'Parent', gca);set(resultText, 'String', '结果将显示在这里');
end

系统功能与实现详解

1. 系统架构

本系统包含三个主要模块:

  • 模板匹配模块:创建数字模板并进行匹配识别
  • 神经网络模块:构建并训练卷积神经网络
  • 交互演示模块:允许用户手写数字进行实时识别

2. 数据处理流程

  1. 数据加载

    • 使用MATLAB自带的手写数字数据集
    • 包含0-9共10类数字图像
  2. 图像预处理

    % 转换为灰度图
    img = rgb2gray(img);% 调整大小为28x28像素
    img = imresize(img, [28, 28]);% 归一化处理
    img = im2double(img);% 图像二值化
    img = imbinarize(img);
    
  3. 模板创建

    • 对每个数字类别的图像求平均
    • 生成0-9的数字模板

3. 模板匹配算法

% 计算与每个模板的相关系数
correlations = zeros(1, 10);
for digit = 0:9corrMatrix = corrcoef(testImg(:), templates(:, :, digit+1)(:));correlations(digit+1) = corrMatrix(1, 2);
end% 选择最相似的数字
[~, predLabel] = max(correlations);

参考源码 手写体识别 模板匹配识别方法 youwenfan.com/contentcsa/78091.html

4. 神经网络架构

本系统使用了一个高效的卷积神经网络结构:

层类型参数设置输出尺寸
输入层28x28x1图像28x28x1
卷积层15x5核, 32个滤波器28x28x32
批量归一化层1-28x28x32
ReLU激活层1-28x28x32
最大池化层12x2池化, 步长214x14x32
卷积层23x3核, 64个滤波器14x14x64
批量归一化层2-14x14x64
ReLU激活层2-14x14x64
最大池化层22x2池化, 步长27x7x64
全连接层1128个神经元128
ReLU激活层3-128
Dropout层丢弃率40%128
全连接层210个神经元10
Softmax层-10
分类层--

5. 训练配置

options = trainingOptions('adam', ...'InitialLearnRate', 0.001, ...'MaxEpochs', 15, ...'MiniBatchSize', 128, ...'ValidationData', {XTest, YTest}, ...'Plots', 'training-progress');

6. 性能比较

方法准确率优点缺点
模板匹配75-85%实现简单,计算快速对形变和旋转敏感
神经网络98-99%鲁棒性强,识别精度高需要大量数据和训练时间

7. 交互演示功能

系统提供了一个绘图界面:

  1. 用户在白色画布上手写数字
  2. 点击"识别"按钮进行预测
  3. 点击"清除"按钮重置画布
  4. 显示识别结果和置信度

关键技术与创新点

  1. 多方法融合

    • 同时实现模板匹配和神经网络两种方法
    • 提供性能对比分析
  2. 特征可视化

    • 展示卷积层提取的特征图
    • 帮助理解神经网络工作原理
  3. 交互式界面

    • 实时手写识别演示
    • 显示预处理过程和识别结果
  4. 全面的评估

    • 混淆矩阵分析
    • 错误分类可视化
    • 准确率对比

系统扩展建议

  1. 数据增强

    % 添加旋转、平移、缩放等变换
    augmenter = imageDataAugmenter(...'RandRotation', [-15, 15], ...'RandXTranslation', [-3, 3], ...'RandYTranslation', [-3, 3]);
    
  2. 迁移学习

    % 使用预训练的ResNet或MobileNet
    net = resnet50;
    lgraph = layerGraph(net);
    
  3. 模型优化

    • 添加注意力机制
    • 尝试不同的网络架构
    • 使用贝叶斯优化调整超参数
  4. 实时视频识别

    • 集成摄像头输入
    • 实现实时手写数字识别
  5. 移动端部署

    • 使用MATLAB Coder生成C++代码
    • 部署到移动设备或嵌入式系统

这个系统全面展示了手写数字识别的关键技术和实现方法,通过交互式界面增强了用户体验,适用于教育演示和实际应用开发。

http://www.dtcms.com/a/302126.html

相关文章:

  • 【论文阅读53】-CNN-LSTM-滑坡风险随时间变化研究
  • 【论文阅读】Safety Alignment Should Be Made More Than Just a Few Tokens Deep
  • cacti的RCE
  • 计算机视觉---Halcon概览
  • 实用工具类分享:BeanCopyUtils 实现对象深浅拷贝高效处理
  • 墨者:SQL手工注入漏洞测试(MySQL数据库-字符型)
  • haproxy实列
  • 开源AI智能体-JoyAgent集成Deepseek
  • AI论文阅读方法+arixiv
  • 元宇宙工厂前端新形态:Three.js与WebGL实现3D产线交互的轻量化之路
  • 使用std::transform实现并发计算
  • Java 开发新人,入职后的环境搭建和配置
  • 安宝特方案丨AI算法能力开放平台:适用于人工装配质检、点检、实操培训
  • Netty中trySuccess和setSuccess的区别
  • python-内存管理
  • 【FAQ】MS Dynamics 365 Sales配置方法汇总
  • Linux中应用程序的安装于管理
  • Java面试宝典:Spring Boot
  • 基于BEKK-GARCH模型的参数估计、最大似然估计以及参数标准误估计的MATLAB实现
  • 【Linux学习】(12)环境变量
  • 自定义spring-boot-starter
  • STM32F4—电源管理器
  • 网络安全笔记
  • 图像处理第三篇:初级篇(续)—— 照明的理论知识
  • Springboot社区养老保险系统小程序
  • 基础算法思想——分治
  • 服务器防护教程 - 宝塔篇
  • 大模型应用开发1-认识大模型
  • 【Linux】编辑器vim和编译器gcc/g++
  • go‑cdc‑chunkers:用 CDC 实现智能分块 强力去重