基于 PyTorch 模型训练优化、FastAPI 跨域配置与 Vue 响应式交互的手写数字识别
0 序言
本文围绕手写数字识别项目展开,涵盖前端交互(Vue)、后端接口(FastAPI)、CNN模型训练(PyTorch)全流程,把之前学习过的知识综合运用起来。内容包含环境搭建、代码实现、操作步骤及问题解决,借助该项目来掌握前后端分离项目开发、MNIST数据集应用、LeNet5模型训练与部署,获取可复用的图像分类项目流程,快速复现或扩展类似项目。
1 项目基础与环境准备
1.1 项目介绍与目标
1.1.1 项目介绍
手写数字识别是计算机视觉入门经典任务,基于MNIST数据集(含6万训练样本、1万测试样本,每个样本为28×28灰度图,对应0-9数字),采用LeNet5卷积神经网络(CNN)实现分类,架构为前端交互+后端预测+模型支撑
的前后端分离模式。
1.1.2 项目目标
- 前端:提供画布供用户手写数字,完成图像预处理(缩放、灰度转换),发起后端请求并展示结果。
- 后端:接收前端图像,通过预训练LeNet5模型预测数字,返回结果。
- 整体:实现端到端识别,准确率达98%以上,掌握全流程开发逻辑。
具体的流程可以参考下图:
1.2 开发环境准备
1.2.1 基础环境要求
- 编程语言:Python 3.8+(后端+模型训练)、JavaScript(前端Vue)
- 运行环境:Node.js 16+(Vue项目依赖管理)、Python虚拟环境
1.2.2 依赖库安装
1.2.2.1 Python依赖(后端+模型)
通过pip
安装核心库,命令如下:
# 后端框架与网络请求
pip install fastapi uvicorn
# PyTorch核心(含CPU版本,GPU版本需替换命令)
pip install torch torchvision
# 图像处理与数据处理
pip install pillow numpy
# 前端请求库(Vue侧后续安装)
这里还有个要注意的点就是,如果电脑里有多个python环境,在这里用pip下载最好指定一下,不然会默认用全局的python环境去下载。
比如:
D:\Python\Scripts\pip3.12.exe install [安装包]
1.2.2.2 Vue依赖(前端)
进入前端项目目录(mnist-frontend
),通过npm
安装:
# 初始化Vue项目(若未创建)
npm create vue@latest mnist-frontend
# 进入目录并安装axios(请求后端)
cd mnist-frontend
npm install axios
1.2.3 项目目录结构
参考实际文件路径(D:\ProjectPython\DNN_CNN
),规范结构如下(便于后续复用):
DNN_CNN/ # 项目根目录
├─ mnist-frontend/ # 前端Vue项目
│ ├─ src/
│ │ ├─ App.vue # 前端核心文件(模板+逻辑+样式)
│ │ ├─ main.js # Vue入口文件
│ │ └─ style.css # 全局样式(本项目用组件内联样式)
│ └─ package.json # Vue依赖配置
├─ CNN_Proj.py # 模型训练脚本(生成权重文件)
├─ main.py # 后端FastAPI服务脚本
├─ LeNet5_mnist.pth # 预训练模型权重(训练后生成)
└─ dataset/ # MNIST数据集(训练脚本自动下载)
2 前端实现(Vue)
2.1 前端核心功能定位
前端是用户交互入口,需解决如何让用户输入数字
、如何将输入转为模型可识别格式
和如何与后端通信
三个核心问题,最终实现绘制→预处理→请求→展示的这一闭环。
2.2 模板结构设计(App.vue
的<template>
)
模板需包含交互组件+反馈组件
,结构如下:
<template><div class="container"><h1>手写数字识别</h1><!-- 1. 主画布(用户绘制数字) --><canvas ref="canvas" width="280" height="280" @mousedown="startDrawing" @mousemove="draw" @mouseup="stopDrawing"@mouseleave="stopDrawing"></canvas><!-- 2. 调试画布(预览28×28预处理图像,便于排查问题) --><div class="debug-section" v-show="showDebug"><h3>预处理后图像(28x28 放大)</h3><canvas ref="debugCanvas" width="280" height="280"></canvas><p class="debug-info">实际尺寸 28x28 | 放大 10 倍</p></div><!-- 3. 控制按钮(功能操作) --><div class="buttons"><button @click="clearCanvas" :disabled="isLoading">清除画布</button><button @click="predictDigit" :disabled="isLoading">{{ isLoading ? '识别中...' : '识别' }}</button><button @click="toggleDebug">显示/隐藏调试</button></div><!-- 4. 结果与错误反馈 --><div class="result" v-if="recognitionResult">识别结果:{{ recognitionResult }}</div><div class="error" v-if="errorMessage">错误:{{ errorMessage }}</div></div>
</template>
2.3 核心逻辑实现(App.vue
的<script setup>
)
2.3.1 响应式变量定义
通过Vue的ref
定义状态变量,确保视图与数据同步:
import { ref, onMounted, nextTick, watch } from 'vue';
import axios from 'axios';// 画布DOM引用
const canvas = ref(null);
const debugCanvas = ref(null);
// 控制状态
const showDebug = ref(false); // 调试视图开关
const isDrawing = ref(false); // 绘制状态
const isLoading = ref(false); // 识别加载状态
// 结果反馈
const recognitionResult = ref(''); // 识别结果
const errorMessage = ref(''); // 错误信息
// 绘制辅助变量
let ctx = null; // 主画布上下文
let debugCtx = null; // 调试画布上下文
let lastX = 0; // 上一次绘制X坐标
let lastY = 0; // 上一次绘制Y坐标
2.3.2 画布初始化(onMounted
钩子)
画布需在DOM渲染完成后初始化,确保上下文获取成功,同时配置绘制参数(匹配模型输入要求):
onMounted(async () => {await nextTick(); // 等待DOM完全渲染// 主画布初始化(280×280,后续缩放为28×28,避免绘制精度不足)if (canvas.value) {ctx = canvas.value.getContext('2d', { willReadFrequently: true });if (ctx) {ctx.fillStyle = '#ffffff'; // 纯白背景(匹配MNIST数据集背景)ctx.fillRect(0, 0, 280, 280);ctx.lineWidth = 12; // 画笔宽度(过细会导致预处理后线条消失)ctx.strokeStyle = 'black'; // 黑色画笔(与MNIST数字颜色一致)ctx.lineCap = 'round'; // 画笔端点圆润(避免锯齿)ctx.lineJoin = 'round'; // 画笔拐角圆润(提升绘制体验)} else {errorMessage.value = '主画布初始化失败,请刷新';}}// 调试画布初始化(与主画布逻辑一致,用于预览预处理结果)if (debugCanvas.value) {debugCtx = debugCanvas.value.getContext('2d', { willReadFrequently: true });if (debugCtx) {debugCtx.fillStyle = '#ffffff';debugCtx.fillRect(0, 0, 280, 280);} else {console.warn('调试画布初始化失败(不影响主功能)');}}
});
2.3.3 绘制逻辑(鼠标事件处理)
通过mousedown
/mousemove
/mouseup
事件实现连续绘制,需处理画布缩放导致的坐标偏移
问题:
// 开始绘制(记录初始坐标)
function startDrawing(e) {if (!ctx) return;isDrawing.value = true;const rect = canvas.value.getBoundingClientRect(); // 获取画布在页面中的位置// 计算画布内真实坐标(解决浏览器缩放导致的坐标偏差)lastX = (e.clientX - rect.left) * (canvas.value.width / rect.width);lastY = (e.clientY - rect.top) * (canvas.value.height / rect.height);ctx.beginPath();ctx.moveTo(lastX, lastY);ctx.lineTo(lastX + 0.1, lastY + 0.1); // 绘制初始点(避免点击不拖动无痕迹)ctx.stroke();
}// 实时绘制
function draw(e) {if (!ctx || !isDrawing.value) return;const rect = canvas.value.getBoundingClientRect();const x = (e.clientX - rect.left) * (canvas.value.width / rect.width);const y = (e.clientY - rect.top) * (canvas.value.height / rect.height);ctx.lineTo(x, y); // 连接上一坐标与当前坐标ctx.stroke();lastX = x; // 更新上一坐标lastY = y;
}// 结束绘制
function stopDrawing() {isDrawing.value = false;
}
2.3.4 图像预处理(关键步骤)
模型输入要求为1×1×28×28灰度图(batch×通道×高×宽)+ 归一化
,需通过辅助函数实现转换:
2.3.4.1 画布空检测(checkCanvasEmpty
)
避免前端发送空图像请求,通过亮度阈值判断是否有绘制内容:
async function checkCanvasEmpty() {return new Promise((resolve) => {if (!ctx) { resolve(true); return; }const imageData = ctx.getImageData(0, 0, 280, 280);const data = imageData.data; // 像素数据(RGBA,每4个值对应一个像素)const threshold = 250; // 亮度阈值(纯白亮度255,低于250视为有绘制)for (let i = 0; i < data.length; i += 4) {const brightness = (data[i] + data[i+1] + data[i+2]) / 3; // 计算亮度(灰度值)if (brightness < threshold) {resolve(false); // 有绘制内容return;}}resolve(true); // 无绘制内容});
}
2.3.4.2 28×28灰度转换与反转(canvasTo28x28Gray
)
MNIST数据集为黑底白字
,而前端绘制是白底黑字
,需反转颜色;同时缩放为28×28:
function canvasTo28x28Gray(canvasEl) {return new Promise((resolve) => {// 1. 创建临时画布(28×28,模型输入尺寸)const tempCanvas = document.createElement('canvas');tempCanvas.width = 28;tempCanvas.height = 28;const tempCtx = tempCanvas.getContext('2d');if (!tempCtx) { resolve({ imgBlob: null, tempCanvas: null }); return; }// 2. 缩放绘制(保持比例居中,避免拉伸)tempCtx.fillStyle = '#ffffff';tempCtx.fillRect(0, 0, 28, 28); // 填充纯白背景const scale = Math.min(28 / canvasEl.width, 28 / canvasEl.height); // 等比例缩放const xOffset = (28 - canvasEl.width * scale) / 2; // X轴居中偏移const yOffset = (28 - canvasEl.height * scale) / 2; // Y轴居中偏移tempCtx.drawImage(canvasEl,0, 0, canvasEl.width, canvasEl.height, // 源图像区域xOffset, yOffset, canvasEl.width * scale, canvasEl.height * scale // 目标绘制区域);// 3. 灰度转换与颜色反转(匹配MNIST数据分布)const imageData = tempCtx.getImageData(0, 0, 28, 28);const data = imageData.data;for (let i = 0; i < data.length; i += 4) {const brightness = (data[i] + data[i+1] + data[i+2]) / 3; // 灰度值const inverted = 255 - brightness; // 反转:白底黑字→黑底白字data[i] = data[i+1] = data[i+2] = inverted; // RGB通道统一为反转后值data[i+3] = 255; // 透明度保持100%}tempCtx.putImageData(imageData, 0, 0);// 4. 生成Blob(用于FormData传输)tempCanvas.toBlob((blob) => {resolve({ imgBlob: blob, tempCanvas: tempCanvas });}, 'image/png', 1.0); // 无损压缩,避免图像细节丢失});
}
2.3.5 后端请求逻辑(predictDigit
)
通过axios
发送POST请求,传递图像Blob,处理响应与错误:
async function predictDigit() {if (!ctx) { errorMessage.value = '画布未初始化,请刷新'; return; }isLoading.value = true;errorMessage.value = '';try {// 步骤1:检查画布是否有内容const isEmpty = await checkCanvasEmpty();if (isEmpty) {errorMessage.value = '请先绘制数字';isLoading.value = false;return;}// 步骤2:预处理图像(转为28×28灰度Blob)const { imgBlob, tempCanvas } = await canvasTo28x28Gray(canvas.value);if (!imgBlob) { throw new Error('图像转换失败,无法生成有效数据'); }// 步骤3:预览调试图像(若开启调试)if (showDebug.value && debugCtx && tempCanvas) {debugCtx.drawImage(tempCanvas, 0, 0, 280, 280); // 放大10倍显示}// 步骤4:构建FormData(后端接收文件格式)const formData = new FormData();formData.append('file', imgBlob, 'digit.png'); // 参数名'file'需与后端一致// 步骤5:发送请求(不手动设置Content-Type,axios自动处理边界符)const response = await axios.post('http://localhost:8000/predict', // 后端接口地址formData);// 步骤6:处理响应(验证数据格式)if (response.data && 'predicted_digit' in response.data) {recognitionResult.value = response.data.predicted_digit;} else {throw new Error('后端返回数据格式异常');}} catch (error) {// 精细化错误提示(便于排查问题)if (error.response) {// 后端返回错误(如422参数错误、500服务器错误)errorMessage.value = `识别失败:${error.response.status} - ${error.response.data?.error || error.response.data?.detail || '未知错误'}`;} else if (error.request) {// 无响应(后端未启动、跨域问题)errorMessage.value = '识别失败:无法连接后端服务,请检查后端是否运行';} else {// 前端本地错误(如图像转换失败)errorMessage.value = `识别失败:${error.message}`;}console.error('预测错误详情:', error);} finally {isLoading.value = false; // 无论成功失败,结束加载状态}
}
这里简单说下图像Blob,图像Blob(Binary Large Object)简单说就是以二进制形式存储的图像文件数据,比如PNG、JPG格式的图像在计算机中实际存储的字节流,就属于Blob。
在项目里,前端把画布绘制的内容(28×28灰度图)转成Blob,是因为:
- 后端接口接收的是“文件”类型数据(
UploadFile
),Blob能模拟文件的二进制格式; - 配合
FormData
(表单数据)传递时能保持图像的原始编码,避免文本格式转换导致的数据损坏。
比如项目中canvasTo28x28Gray
函数里,通过tempCanvas.toBlob(...)
生成Blob,再用formData.append('file', imgBlob, 'digit.png')
附加到请求里,就能让后端像接收本地图片文件一样解析它。
2.4 样式设计(App.vue
的<style scoped>
)
样式保证交互友好性,没有放过多冗杂的东西,核心代码如下:
<style scoped>
.container {text-align: center;padding: 20px;max-width: 600px;margin: 0 auto; /* 容器居中 */
}
canvas {border: 2px solid #ccc;margin: 10px auto;display: block;background-color: #ffffff; /* 匹配画布初始化背景 */touch-action: none; /* 禁止浏览器默认触摸行为(适配移动端) */
}
.debug-section {margin-top: 20px;padding: 15px;background-color: #f9f9f9;border-radius: 8px; /* 圆角提升美观度 */
}
.debug-info {color: #666;font-size: 14px;margin-top: 5px;
}
.buttons {margin: 20px 0;
}
button {padding: 10px 20px;margin: 0 10px;cursor: pointer;background-color: #42b983; /* Vue默认主题色,辨识度高 */color: white;border: none;border-radius: 4px;transition: opacity 0.3s; /* hover过渡效果 */
}
button:disabled {background-color: #ccc;cursor: not-allowed; /* 禁用状态光标提示 */opacity: 0.7;
}
button:hover:not(:disabled) {opacity: 0.8; /* hover时降低透明度,反馈交互 */
}
.result {font-size: 20px;margin-top: 20px;color: #42b983; /* 成功颜色 */
}
.error {font-size: 16px;color: #e53e3e; /* 错误颜色 */margin-top: 10px;
}
</style>
3 后端实现(FastAPI + PyTorch)
3.1 后端核心功能定位
后端需解决如何接收前端图像
、如何用模型预测
和如何返回结果
这三个问题,核心是提供高可用的预测接口,确保与前端数据格式兼容、与模型输入匹配。
3.2 FastAPI服务搭建
3.2.1 初始化FastAPI实例
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import torch
import torch.nn as nn
from PIL import Image
import numpy as np# 初始化FastAPI应用
app = FastAPI()
3.2.2 跨域配置(关键)
前端(默认5173端口)与后端(8000端口)端口不同,会触发浏览器跨域拦截,需配置CORSMiddleware
:
app.add_middleware(CORSMiddleware,allow_origins=["*"], # 开发环境允许所有源(生产环境需指定具体域名)allow_credentials=True, # 允许携带Cookie(本项目暂用不到,保留扩展性)allow_methods=["*"], # 允许所有HTTP方法(GET/POST等)allow_headers=["*"], # 允许所有请求头
)
3.3 LeNet5模型定义(与训练脚本一致)
模型结构必须与训练时完全相同,否则权重加载失败。LeNet5是经典CNN架构,适配MNIST数据:
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 网络层序列(卷积→激活→池化→卷积→激活→池化→卷积→激活→展平→全连接→激活→全连接)self.net = nn.Sequential(# C1层:1→6通道,5×5卷积核,padding=2(保持28×28输出)nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),nn.Tanh(), # 激活函数(LeNet5原设计,引入非线性)nn.AvgPool2d(kernel_size=2, stride=2), # S2层:2×2平均池化,输出14×14# C3层:6→16通道,5×5卷积核(无padding,输出10×10)nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2), # S4层:输出5×5# C5层:16→120通道,5×5卷积核(输出1×1,等效全连接)nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),nn.Tanh(),nn.Flatten(), # 展平:120×1×1→120维向量# F6层:全连接,120→84nn.Linear(in_features=120, out_features=84),nn.Tanh(),# 输出层:全连接,84→10(对应0-9数字)nn.Linear(in_features=84, out_features=10))# 前向传播(定义数据流动路径)def forward(self, x):return self.net(x)
3.4 模型加载与图像预处理
3.4.1 模型初始化与权重加载
加载训练生成的LeNet5_mnist.pth
权重,切换为评估模式(禁用训练相关层):
# 初始化模型
model = LeNet5()
# 加载权重(map_location='cpu'适配无GPU环境)
state_dict = torch.load('LeNet5_mnist.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict) # 权重参数映射到模型
model.eval() # 切换为评估模式(关键:禁用Dropout/BatchNorm等训练层)
3.4.2 图像预处理函数(preprocess_image
)
前端传入的是28×28 PNG图像,需转为模型要求的1×1×28×28张量+归一化
:
def preprocess_image(image):# 1. 转为灰度图(即使前端已处理,后端二次确认,避免格式错误)image = image.convert('L') # 'L'模式为单通道灰度图# 2. 确保尺寸为28×28(前端可能因异常未缩放,后端兜底)image = image.resize((28, 28), Image.Resampling.LANCZOS) # 高质量插值缩放# 3. 转为numpy数组并归一化(匹配训练时的数据分布)image = np.array(image, dtype=np.float32) # 转为32位浮点数数组mean = 0.1307 # MNIST数据集均值(训练时计算,需固定)std = 0.3081 # MNIST数据集标准差(训练时计算,需固定)image = (image / 255.0 - mean) / std # 步骤:0-255→0-1→标准化(均值0,标准差1)# 4. 调整维度(模型输入:batch×通道×高×宽)image = np.expand_dims(image, axis=0) # 增加通道维度:(28,28)→(1,28,28)image = np.expand_dims(image, axis=0) # 增加batch维度:(1,28,28)→(1,1,28,28)# 5. 转为PyTorch张量return torch.tensor(image)
3.5 预测接口实现(/predict
)
定义POST接口,接收前端UploadFile
类型文件,处理流程为读取图像→预处理→预测→返回结果
:
@app.post("/predict")
async def predict_digit(file: UploadFile = File(...)):try:# 1. 打印调试信息(便于排查文件接收问题)print(f"收到文件: {file.filename}, 类型: {file.content_type}")# 2. 读取图像(PIL.Image打开)image = Image.open(file.file)print(f"原始图像 - 尺寸: {image.size}, 模式: {image.mode}")# 3. 图像预处理input_tensor = preprocess_image(image)print(f"预处理后 - 张量维度: {input_tensor.shape}, 数据类型: {input_tensor.dtype}")# 4. 模型预测(禁用梯度计算,节省资源)with torch.no_grad():output = model(input_tensor) # 模型输出:(1,10)(1个样本,10个类别概率)predicted_digit = torch.argmax(output, dim=1).item() # 取概率最大的类别# 5. 返回结果(JSON格式,前端可直接解析)return {"predicted_digit": predicted_digit}except Exception as e:# 异常捕获(打印错误信息,返回错误提示)print(f"处理请求时出错: {str(e)}")return {"error": str(e)}# 启动服务(当脚本直接运行时)
if __name__ == "__main__":import uvicornuvicorn.run(app, host="0.0.0.0", port=8000) # 0.0.0.0允许局域网访问,端口8000
4 模型训练(PyTorch + MNIST)
4.1 训练核心目标
生成可复用的权重文件(LeNet5_mnist.pth
),该模型在MNIST测试集上准确率为98.17%,准确率还算不错,用它来为后端提供预测能力。
4.2 训练脚本实现(CNN_Proj.py
)
4.2.1 数据准备(prepare_data
)
加载MNIST数据集,应用与后端一致的预处理(归一化),用DataLoader
按批次加载:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimSun']
plt.rcParams['axes.unicode_minus'] = Falsedef prepare_data():# 数据转换 pipeline(与后端预处理逻辑一致)transform = transforms.Compose([transforms.ToTensor(), # 转为张量:(H,W,C)→(C,H,W),值归一化到0-1transforms.Normalize(0.1307, 0.3081) # 标准化(均值+标准差)])# 加载训练集(train=True),自动下载到./dataset/mnist/train_dataset = datasets.MNIST(root='./dataset/mnist/',train=True,download=True,transform=transform)# 加载测试集(train=False)test_dataset = datasets.MNIST(root='./dataset/mnist/',train=False,download=True,transform=transform)# 创建DataLoader(按批次加载,训练集打乱)train_loader = DataLoader(train_dataset,batch_size=256, # 批次大小(根据内存调整,256兼顾速度与内存)shuffle=True # 训练集打乱,增强泛化能力)test_loader = DataLoader(test_dataset,batch_size=256,shuffle=False # 测试集无需打乱)return train_loader, test_loader
4.2.2 模型训练(train_model
)
定义训练循环,包含“前向传播→损失计算→反向传播→参数更新”核心步骤:
def train_model(model, train_loader, epochs=5, lr=0.9):# 1. 损失函数:交叉熵损失(分类任务专用,含Softmax激活)criterion = nn.CrossEntropyLoss()# 2. 优化器:随机梯度下降(SGD),lr=0.9为LeNet5经典学习率optimizer = torch.optim.SGD(model.parameters(), lr=lr)# 3. 记录损失(用于绘制曲线,观察训练效果)train_losses = []# 4. 训练循环print("\n开始训练...")for epoch in range(epochs):model.train() # 切换为训练模式(启用Dropout/BatchNorm)total_loss = 0.0# 遍历训练集批次for batch_idx, (images, labels) in enumerate(train_loader):# 前向传播:输入图像,获取模型输出outputs = model(images)# 计算损失:输出与真实标签的差异loss = criterion(outputs, labels)# 反向传播与参数更新optimizer.zero_grad() # 清空上一轮梯度(避免累积)loss.backward() # 反向传播计算梯度optimizer.step() # 根据梯度更新模型参数# 记录损失train_losses.append(loss.item())total_loss += loss.item()# 每100个批次打印一次中间结果if (batch_idx + 1) % 100 == 0:print(f"轮次 [{epoch+1}/{epochs}], 批次 [{batch_idx+1}/{len(train_loader)}], "f"当前批次损失: {loss.item():.4f}")# 打印本轮平均损失avg_loss = total_loss / len(train_loader)print(f"轮次 [{epoch+1}/{epochs}] 平均损失: {avg_loss:.4f}")# 5. 绘制损失曲线(直观观察训练收敛情况)plt.figure(figsize=(10, 4))plt.plot(train_losses, label='训练损失')plt.xlabel('批次')plt.ylabel('损失值')plt.title('训练损失变化曲线')plt.legend()plt.show()# 6. 保存模型权重(仅保存状态字典,节省空间)torch.save(model.state_dict(), 'LeNet5_mnist.pth')print(f"模型已保存为 'LeNet5_mnist.pth'")return model, train_losses
4.2.3 模型测试(test_model
)
评估模型在测试集上的准确率,验证泛化能力:
def test_model(model, test_loader):model.eval() # 切换为评估模式correct = 0 # 正确预测数total = 0 # 总样本数# 禁用梯度计算(测试阶段无需更新参数)with torch.no_grad():print("\n开始测试...")for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1) # 取概率最大的类别total += labels.size(0)correct += (predicted == labels).sum().item() # 统计正确数# 计算并打印准确率accuracy = 100 * correct / totalprint(f"测试集准确率: {accuracy:.2f}%")return accuracy
4.2.4 主函数(串联训练流程)
def main():# 步骤1:准备数据train_loader, test_loader = prepare_data()print("数据准备完成,训练集样本数:", len(train_loader.dataset), "测试集样本数:", len(test_loader.dataset))# 步骤2:初始化模型(与后端LeNet5完全一致)model = LeNet5()print("\nLeNet-5模型初始化完成")# 步骤3:训练模型trained_model, losses = train_model(model, train_loader, epochs=5)# 步骤4:测试模型test_model(trained_model, test_loader)if __name__ == "__main__":main()
5 完整项目操作流程
5.1 前置准备
- 安装基础环境:
- 安装Python 3.8+(官网下载)
- 安装Node.js 16+(官网下载)
- 搭建项目目录:
- 在
D:\ProjectPython\
下创建DNN_CNN
文件夹(根目录)。 - 在
DNN_CNN
下创建mnist-frontend
文件夹(前端目录)。
- 在
- 安装依赖:
- 打开命令提示符(CMD),执行Python依赖安装:
pip install fastapi uvicorn torch torchvision pillow numpy
- 进入前端目录,执行Vue依赖安装:
cd D:\ProjectPython\DNN_CNN\mnist-frontend
npm create vue@latest . # 初始化Vue项目,全部选“NO”(简化配置)
npm install axios
5.2 模型训练(可选,已有权重可跳过)
- 在
DNN_CNN
根目录创建CNN_Proj.py
,第4章的训练脚本程序放在该py文件里。 - 运行训练脚本:
cd D:\ProjectPython\DNN_CNN python CNN_Proj.py
- 等待训练完成,根目录会生成
LeNet5_mnist.pth
(权重文件),这个时候可以管擦测试集准确率,一般来说满足≥95%就可以了。
比如我这边自己训练的,
从训练结果来看,这个 LeNet-5 模型在 MNIST 测试集上达到了98.17% 的准确率,对于基础的手写数字识别任务来说,这个性能算是比较理想的,直接用于简单的手写数字识别这个实际场景是足够的。
5.3 后端部署
- 在
DNN_CNN
根目录创建main.py
,程序详见第3章的后端脚本程序。 - 确保
LeNet5_mnist.pth
在根目录下,启动后端服务:python main.py
- 看到“Uvicorn running on http://0.0.0.0:8000”表示启动成功,不要关闭CMD窗口。
这里有两个点要说清楚,
第一,如果直接在 Python 里运行 main.py
(比如点击 IDE 的“运行”按钮),程序会加载模型 → 定义 FastAPI 实例 → 定义路由,但不会但不会启动 Web 服务!代码里的 API 接口(/predict
)根本没法被外部访问, Postman 也连不上。
第二,uvicorn main:app --reload
是干啥的? uvicorn
是一个 ASGI 服务器,作用是:
- 找到你的
main.py
文件,加载里面的app = FastAPI()
实例 - 启动一个 Web 服务,让你的 API(
/predict
)能被外部访问(比如 Postman、前端页面 ) --reload
:文件改动时自动重启服务(开发时超方便,不用手动重启 )
main.py完整程序如下:
# 后端 main.py(PyTorch 版本)
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import torch
import torch.nn as nn
from PIL import Image
import numpy as npapp = FastAPI()# 允许跨域
app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)# 定义与CNN_Proj.py中一致的LeNet5模型结构
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),nn.Tanh(),nn.Flatten(),nn.Linear(in_features=120, out_features=84),nn.Tanh(),nn.Linear(in_features=84, out_features=10))def forward(self, x):return self.net(x)# 初始化模型
model = LeNet5()# 加载权重(无需修改键名,直接匹配)
state_dict = torch.load('LeNet5_mnist.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval() # 切换为评估模式# 图像预处理(适配MNIST数据集的预处理方式)
def preprocess_image(image):# 确保图像转为灰度图(即使前端已处理,后端再次确认)image = image.convert('L') # 转为灰度图# 确保图像尺寸为28x28(即使前端已处理,后端再次确认)image = image.resize((28, 28), Image.Resampling.LANCZOS) # 使用高质量插值方法# 转换为numpy数组并归一化image = np.array(image, dtype=np.float32) # 转为数组# 按照训练时的方式归一化(MNIST的均值和标准差)mean = 0.1307std = 0.3081image = (image / 255.0 - mean) / std # 先归一化到0-1再标准化# 确保输入维度正确image = np.expand_dims(image, axis=0) # 增加通道维度 (1,28,28)image = np.expand_dims(image, axis=0) # 增加batch维度 (1,1,28,28)return torch.tensor(image)# 预测接口
@app.post("/predict")
async def predict_digit(file: UploadFile = File(...)):try:# 打印文件基本信息用于调试print(f"收到文件: {file.filename}, 类型: {file.content_type}")# 读取图像image = Image.open(file.file)print(f"原始图像 - 尺寸: {image.size}, 模式: {image.mode}") # 检查图像初始状态# 预处理input_tensor = preprocess_image(image)print(f"预处理后 - 张量维度: {input_tensor.shape}, 数据类型: {input_tensor.dtype}") # 检查处理后状态# 预测with torch.no_grad():output = model(input_tensor)predicted_digit = torch.argmax(output, dim=1).item()return {"predicted_digit": predicted_digit}except Exception as e:# 打印异常信息用于调试print(f"处理请求时出错: {str(e)}")return {"error": str(e)}if __name__ == "__main__":import uvicornuvicorn.run(app, host="0.0.0.0", port=8000)```接下来简单展示一下启动步骤:
#### 1. 打开终端,进入 `main.py` 所在目录
以我的文件结构来举例:
D:\ProjectPython\DNN_CNN
├── main.py
├── CNN_Proj.py
└── LeNet5_mnist.pth
在 **VS Code** 里:
- 点击左侧“资源管理器”,找到 `DNN_CNN` 文件夹
- 点击顶部菜单 **终端 → 新建终端**(会自动进入当前目录 )
- 也可以直接用cd + 文件路径#### 2. 运行 `uvicorn` 命令
在终端里输入:
```bash
uvicorn main:app --reload
main:app
:告诉uvicorn
:- 找
main.py
文件(main
) - 加载里面的
app = FastAPI()
实例(app
)
- 找
--reload
:开发模式,改代码后自动重启
3. 看启动结果
如果成功,终端会显示:
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [12345]
INFO: Started server process [12346]
INFO: Waiting for application startup.
INFO: Application startup complete.
这说明:
- 你的 API 服务启动了,地址是
http://127.0.0.1:8000
- 现在可以用 Postman 访问
http://127.0.0.1:8000/predict
测试
如果想要保险起见,可以先用下面这一步来测试一下,测试下API的情况。
4. 测试 API(用 Postman 或浏览器)
打开 Postman:
- 请求方法:POST
- URL:
http://127.0.0.1:8000/predict
- ** Body → form-data**:
Key
选file
,类型选File
Value
选一张手写数字的图片(28x28 黑白图最佳 )
发送请求后,就能看到返回的 predicted_digit
(识别结果 )
打开后配置请求信息:
-
Step 1:选请求方法 + 填 URL
- 选 POST(必须和你
main.py
里的@app.post("/predict")
对应); - 中间 URL 输入框,填
http://127.0.0.1:8000/predict
(就是你 FastAPI 服务的地址 + 接口名)。
- 选 POST(必须和你
-
Step 2:配置 Body(上传图片)
- 点击请求下方的 “Body” 标签 → 勾选 “Form Data”(表单上传,和
main.py
接收UploadFile
对应); - 第一行“Key”输入
file
(必须和main.py
里predict(file: UploadFile = File(...))
的参数名一致); - 第一行“Value”右侧,点击 “File” 按钮(默认是“Text”,要改成文件上传),然后选择一张你的手写数字图片(28x28 黑白图最佳,手机拍的手写数字照片也能试)。
- 点击请求下方的 “Body” 标签 → 勾选 “Form Data”(表单上传,和
-
Step 3:发送请求
- 点击右上角的 “Send” 按钮(蓝色箭头),发送请求。
发送后,右侧会显示服务器返回的结果:
- 成功情况:如果返回类似
{"predicted_digit": 5}
,说明模型识别出图片里的数字是 5,API 调用成功! - 常见问题排查:
- 若显示“Connection refused”:检查 FastAPI 服务是否启动(终端里的
uvicorn
命令有没有在运行); - 若显示“找不到文件”:检查
main.py
里torch.load("LeNet5_mnist.pth")
的模型路径是否正确,确保LeNet5_mnist.pth
和main.py
在同一目录; - 若识别结果错误:检查
preprocess_image
函数的预处理逻辑(比如是否转灰度、是否 resize 到 28x28),要和训练时完全一致。
- 若显示“Connection refused”:检查 FastAPI 服务是否启动(终端里的
5.4 前端部署
首先确保你的 Node.js 环境已经准备好,接下来用 Vue3 + Vite 实现手写数字识别的前端界面并和后端 API 打通:
5.4.1 创建 Vue3 + Vite 项目
- 打开终端(CMD/PowerShell/VS Code 终端都可以 );
- 创建项目(按顺序执行 ):
# 1. 创建 Vue3 项目(项目名 mnist-frontend,模板选 vue)
npm create vite@latest mnist-frontend -- --template vue# 2. 进入项目目录
cd mnist-frontend# 3. 安装依赖(等待安装完成)
npm install# 4. 启动开发环境(启动后,浏览器访问 http://127.0.0.1:5173)
npm run dev
执行完后,浏览器会自动打开 Vue3 初始页面(或手动访问 http://127.0.0.1:5173
),看到 Vue 的欢迎界面,说明项目创建成功。
5.4.2 编写前端界面
在 VS Code 中打开项目目录 mnist-frontend
,找到 src/App.vue
文件,替换成以下完整程序:
<template><div class="container"><h1>手写数字识别</h1><!-- 主画布 --><canvas ref="canvas" width="280" height="280" @mousedown="startDrawing" @mousemove="draw" @mouseup="stopDrawing"@mouseleave="stopDrawing"></canvas><!-- 调试画布(v-show 保持 DOM 存在) --><div class="debug-section" v-show="showDebug"><h3>预处理后图像(28x28 放大)</h3><canvas ref="debugCanvas" width="280" height="280"></canvas><p class="debug-info">实际尺寸 28x28 | 放大 10 倍</p></div><!-- 控制按钮 --><div class="buttons"><button @click="clearCanvas" :disabled="isLoading">清除画布</button><button @click="predictDigit" :disabled="isLoading">{{ isLoading ? '识别中...' : '识别' }}</button><button @click="toggleDebug">显示/隐藏调试</button></div><!-- 结果与错误提示 --><div class="result" v-if="recognitionResult">识别结果:{{ recognitionResult }}</div><div class="error" v-if="errorMessage">错误:{{ errorMessage }}</div></div>
</template>
<script setup>
import { ref, onMounted, nextTick, watch } from 'vue';
import axios from 'axios';
// 响应式变量
const canvas = ref(null);
const debugCanvas = ref(null);
const showDebug = ref(false);
const isDrawing = ref(false);
const isLoading = ref(false);
const recognitionResult = ref('');
const errorMessage = ref('');
let ctx = null;
let debugCtx = null;
let lastX = 0;
let lastY = 0;
// 初始化画布(确保 DOM 渲染完成)
onMounted(async () => {await nextTick(); // 等待 DOM 完全渲染// 主画布初始化if (canvas.value) {ctx = canvas.value.getContext('2d', { willReadFrequently: true });if (ctx) {ctx.fillStyle = '#ffffff'; // 改为纯白背景,与MNIST训练数据背景一致ctx.fillRect(0, 0, 280, 280);ctx.lineWidth = 12; // 调整画笔宽度,避免预处理后线条过细ctx.strokeStyle = 'black';ctx.lineCap = 'round'; // 画笔端点圆润,避免锯齿ctx.lineJoin = 'round'; // 画笔拐角圆润,提升绘制体验} else {errorMessage.value = '主画布初始化失败,请刷新';}} else {errorMessage.value = '未找到主画布元素,请检查代码';}// 调试画布初始化(v-show 已确保 DOM 存在)if (debugCanvas.value) {debugCtx = debugCanvas.value.getContext('2d', { willReadFrequently: true });if (debugCtx) {debugCtx.fillStyle = '#ffffff';debugCtx.fillRect(0, 0, 280, 280);} else {console.warn('调试画布初始化失败(不影响主功能)');}}
});
// 监听 showDebug 变化,重新初始化调试画布
watch(showDebug, (newVal) => {if (newVal && debugCanvas.value && !debugCtx) {debugCtx = debugCanvas.value.getContext('2d', { willReadFrequently: true });if (debugCtx) {debugCtx.fillStyle = '#ffffff';debugCtx.fillRect(0, 0, 280, 280);}}
});
// 绘制逻辑 - 修复坐标计算与绘制连续性问题
function startDrawing(e) {if (!ctx) return;isDrawing.value = true;const rect = canvas.value.getBoundingClientRect();// 计算画布内真实坐标(处理画布缩放场景)lastX = (e.clientX - rect.left) * (canvas.value.width / rect.width);lastY = (e.clientY - rect.top) * (canvas.value.height / rect.height);ctx.beginPath();ctx.moveTo(lastX, lastY);// 绘制初始点(解决点击画布不拖动无痕迹问题)ctx.lineTo(lastX + 0.1, lastY + 0.1);ctx.stroke();
}
function draw(e) {if (!ctx || !isDrawing.value) return;const rect = canvas.value.getBoundingClientRect();// 计算画布内真实坐标const x = (e.clientX - rect.left) * (canvas.value.width / rect.width);const y = (e.clientY - rect.top) * (canvas.value.height / rect.height);ctx.lineTo(x, y);ctx.stroke();lastX = x;lastY = y;
}
function stopDrawing() {isDrawing.value = false;
}
// 清除画布
function clearCanvas() {if (!ctx) return;ctx.fillStyle = '#ffffff';ctx.fillRect(0, 0, 280, 280);// 清除调试画布if (debugCtx) {debugCtx.fillStyle = '#ffffff';debugCtx.fillRect(0, 0, 280, 280);}recognitionResult.value = '';errorMessage.value = '';
}
// 切换调试视图
function toggleDebug() {showDebug.value = !showDebug.value;
}
// 预测逻辑 - 修复FormData构建与错误处理
async function predictDigit() {if (!ctx) {errorMessage.value = '画布未初始化,请刷新';return;}isLoading.value = true;errorMessage.value = '';try {// 检查画布是否有内容(优化阈值,适配纯白背景)const isEmpty = await checkCanvasEmpty();if (isEmpty) {errorMessage.value = '请先绘制数字';isLoading.value = false;return;}// 转换为 28x28 灰度图(前端预处理)const { imgBlob, tempCanvas } = await canvasTo28x28Gray(canvas.value);if (!imgBlob) {throw new Error('图像转换失败,无法生成有效图像数据');}// 显示调试图像(放大)if (showDebug.value && debugCtx && tempCanvas) {debugCtx.drawImage(tempCanvas, 0, 0, 280, 280);}// 调用后端识别 - 修复FormData构建,移除手动设置Content-Type(axios自动处理)const formData = new FormData();formData.append('file', imgBlob, 'digit.png'); // 参数名改为'file',与后端UploadFile参数名匹配const response = await axios.post('http://localhost:8000/predict', formData// 移除手动设置的Content-Type,避免边界符缺失问题);// 验证响应数据格式if (response.data && 'predicted_digit' in response.data) {recognitionResult.value = response.data.predicted_digit;} else {throw new Error('后端返回数据格式异常');}} catch (error) {// 精细化错误提示if (error.response) {// 后端返回错误(如422、500)errorMessage.value = `识别失败:${error.response.status} - ${error.response.data?.error || error.response.data?.detail || '未知错误'}`;} else if (error.request) {// 无响应(如后端未启动、跨域问题)errorMessage.value = '识别失败:无法连接后端服务,请检查后端是否运行';} else {// 前端本地错误(如图像转换)errorMessage.value = `识别失败:${error.message}`;}console.error('预测错误详情:', error);} finally {isLoading.value = false;}
}
// 辅助函数:检查画布是否为空(优化阈值,适配纯白背景)
async function checkCanvasEmpty() {return new Promise((resolve) => {if (!ctx) {resolve(true);return;}const imageData = ctx.getImageData(0, 0, 280, 280);const data = imageData.data;const threshold = 250; // 纯白背景下,低于250视为有绘制内容for (let i = 0; i < data.length; i += 4) {const brightness = (data[i] + data[i+1] + data[i+2]) / 3;if (brightness < threshold) {resolve(false);return;}}resolve(true);});
}
// 辅助函数:Canvas 转 28x28 灰度图(修复图像反转逻辑,匹配MNIST)
function canvasTo28x28Gray(canvasEl) {return new Promise((resolve) => {const tempCanvas = document.createElement('canvas');tempCanvas.width = 28;tempCanvas.height = 28;const tempCtx = tempCanvas.getContext('2d');if (!tempCtx) {resolve({ imgBlob: null, tempCanvas: null });return;}// 1. 绘制时保持图像比例,避免拉伸(居中绘制)tempCtx.fillStyle = '#ffffff';tempCtx.fillRect(0, 0, 28, 28); // 先填充纯白背景// 计算缩放比例(确保图像完全放入28x28画布,保留比例)const scale = Math.min(28 / canvasEl.width, 28 / canvasEl.height);const xOffset = (28 - canvasEl.width * scale) / 2;const yOffset = (28 - canvasEl.height * scale) / 2;tempCtx.drawImage(canvasEl,0, 0, canvasEl.width, canvasEl.height,xOffset, yOffset, canvasEl.width * scale, canvasEl.height * scale);// 2. 转灰度并反转(MNIST:白底黑字 → 黑底白字,增强特征)const imageData = tempCtx.getImageData(0, 0, 28, 28);const data = imageData.data;for (let i = 0; i < data.length; i += 4) {// 计算亮度(灰度值)const brightness = (data[i] + data[i+1] + data[i+2]) / 3;// 反转:白色(高亮度)→ 黑色(0),黑色(低亮度)→ 白色(255),匹配MNIST数据分布const inverted = 255 - brightness;data[i] = data[i+1] = data[i+2] = inverted;data[i+3] = 255; // 保持不透明}tempCtx.putImageData(imageData, 0, 0);// 3. 生成Blob(指定质量,避免数据损坏)tempCanvas.toBlob((blob) => {resolve({ imgBlob: blob, tempCanvas: tempCanvas });}, 'image/png', 1.0); // 1.0表示无损压缩,确保图像细节不丢失});
}
</script>
<style scoped>
.container {text-align: center;padding: 20px;max-width: 600px;margin: 0 auto;
}
canvas {border: 2px solid #ccc;margin: 10px auto;display: block;background-color: #ffffff; /* 匹配初始化的纯白背景 */touch-action: none;
}
.debug-section {margin-top: 20px;padding: 15px;background-color: #f9f9f9;border-radius: 8px;
}
.debug-info {color: #666;font-size: 14px;margin-top: 5px;
}
.buttons {margin: 20px 0;
}
button {padding: 10px 20px;margin: 0 10px;cursor: pointer;background-color: #42b983;color: white;border: none;border-radius: 4px;transition: opacity 0.3s;
}
button:disabled {background-color: #ccc;cursor: not-allowed;opacity: 0.7;
}
button:hover:not(:disabled) {opacity: 0.8;
}
.result {font-size: 20px;margin-top: 20px;color: #42b983;
}
.error {font-size: 16px;color: #e53e3e;margin-top: 10px;
}
</style>
- 已经在
mnist-frontend/src
目录下创建好App.vue
,程序详见第2章的前端脚本程序。 - 启动前端服务:
cd D:\ProjectPython\DNN_CNN\mnist-frontend\src npm run dev
- 看到
Local: http://localhost:5173/
表示启动成功,复制链接在浏览器打开。
5.5 功能测试
- 在浏览器页面的画布上,用鼠标绘制0-9任意数字。
- 点击
显示/隐藏调试
,查看28×28预处理图像。 - 点击
识别
按钮,下方会显示识别结果
。 - 点击
清除画布
可重新绘制,测试其他数字。
结果如下,只列举部分:
当然,你在终端上也可以看到具体的信息,如果出现错误也可以从中看到是什么错误:
在前端页面上也可以通过Fn + 12
来打开浏览器后台查看具体信息。
在你创建好后,如果未更改前后端文件,后续你的启动步骤就只需要两步:
1.启动后端API服务:
uvicorn main:app --reload
2.启动前端开发环境:
npm run dev
6 问题复盘与解决
6.1 错误1:422 Unprocessable Entity(前端请求后端失败)
这个算是一开始很常见的问题,具体来说很大概率基本都是参数名与后端不匹配。
- 原因:前端FormData参数名与后端不匹配(原前端用
image
,后端需file
);手动设置Content-Type: multipart/form-data
导致请求边界符缺失。 - 解决思路:前端
formData.append('file', imgBlob, 'digit.png')
;删除axios的headers
配置,让axios自动处理。
6.2 错误2:预测结果不准确(如“3”识别为“8”)
- 原因:前端图像未反转(与MNIST黑底白字分布相反);画笔过细导致预处理后线条消失。
- 解决思路:在
canvasTo28x28Gray
中添加灰度反转(255 - brightness
);将ctx.lineWidth
设为12-15。
7 小结
7.1 收获
技术栈整合:切身体会Vue(前端交互)
、FastAPI(后端接口)
和PyTorch(CNN模型)
的前后端分离开发模式,理解各模块间的数据流转逻辑(图像→Blob→FormData→张量→预测结果)。
关键技术点:
图像预处理
:灰度转换、尺寸缩放、颜色反转、归一化,核心是“匹配模型训练时的数据分布”。模型部署
:训练权重加载、评估模式切换、无梯度预测,确保模型高效且正确运行。问题排查
:通过调试信息(如后端打印的文件尺寸、张量维度)定位数据格式问题,通过精细化错误提示快速排查接口问题。
7.2 可扩展方向
功能扩展:支持手写字母识别(替换数据集为EMNIST)、多数字识别(修改模型输出层为多分类)。
性能优化:用ResNet-18替换LeNet5提升准确率,用TensorRT加速模型推理,前端添加防抖绘制减少冗余数据。
场景适配:开发移动端页面,添加历史记录功能,部署到云服务器实现公网访问,但相关知识目前还没学完,后面有时间试试。
7.3 可复用方向
本笔记的环境搭建→代码实现→操作流程
可直接复用于图像分类类项目(如验证码识别、水果分类),只需替换三个部分:
- 数据集:将MNIST替换为目标数据集(如EMNIST、Fruits-360)。
- 模型结构:根据数据集复杂度调整CNN层数(简单任务用LeNet5,复杂任务用ResNet)。
- 前端交互:根据输入类型修改交互组件(将画布改为图片上传)。