当前位置: 首页 > news >正文

pytorch深度学习-Lenet-Minist

一、核心框架:卷积神经网络(CNN)

1. 为什么需要 CNN?

传统神经网络(全连接网络)处理图像时存在两个致命问题:

  • 参数爆炸:以 28×28 的 MNIST 图像为例,输入层有 784 个神经元,若第一个全连接层设 1000 个神经元,仅这一层就有 784×1000=784,000 个参数,计算量极大;
  • 忽略空间相关性:图像中相邻像素关联性强(如数字的边缘、纹理),但全连接网络将像素视为独立特征,破坏了空间结构。

CNN 通过局部感受野权值共享池化解决了这些问题,专为网格结构数据(图像、语音)设计。

二、CNN 核心组件详解

1. 卷积层(Convolutional Layer)

卷积层是 CNN 的 "特征提取器",核心是卷积操作,模拟人类视觉系统对局部特征的感知(如边缘、纹理、形状)。

  • 卷积操作原理
    用一个卷积核(Kernel/Filter) 在输入图像上滑动,计算卷积核与对应区域的点积,输出一个特征图(Feature Map)
    例:输入是 32×32 的灰度图(单通道),用一个 5×5 的卷积核,每次滑动 1 个像素(步长 = 1),输出的特征图尺寸为:
    输出尺寸 = (输入尺寸 - 卷积核尺寸 + 2×填充) / 步长 + 1

  • 关键参数

    • in_channels:输入特征图的通道数(如灰度图为 1,RGB 图为 3);
    • out_channels:卷积核数量(每个核提取一种特征,输出对应通道的特征图);
    • kernel_size:卷积核尺寸(如 5×5);
    • stride:滑动步长(步长越大,输出特征图越小);
    • padding:边缘填充(避免边缘特征被忽略,如填充 1 圈,输入尺寸变相增加 2)。
  • 权值共享:一个卷积核在图像上滑动时,所有位置使用同一组权重,大幅减少参数。例如 5×5 的卷积核,无论图像多大,每个核仅需 25 个参数(+1 个偏置)。

2. 池化层(Pooling Layer)

池化层是 "特征压缩器",作用是降低特征图尺寸(减少计算量),同时增强平移不变性(轻微位移不影响特征)。

  • 常见类型

    • 最大池化(MaxPooling):取局部区域最大值(保留最显著特征,如边缘);
    • 平均池化(AvgPooling):取局部区域平均值(保留整体趋势)。
  • 参数

    • kernel_size:池化窗口尺寸(如 2×2);
    • stride:滑动步长(通常与 kernel_size 相同,避免重叠)。
3. 全连接层(Fully Connected Layer)

全连接层是 "分类决策器",将卷积层提取的局部特征整合为全局特征,并映射到输出类别。

  • 原理:层内每个神经元与上一层所有神经元连接(类似传统神经网络),通过矩阵乘法将特征向量转换为类别得分。
  • 注意:全连接层参数较多,通常放在网络末尾,基于卷积层提取的压缩特征做决策。

二、经典模型:LeNet-5

LeNet-5 是 1998 年由 Yann LeCun 提出的首个实用 CNN,专为手写数字识别设计(MNIST 数据集),结构简洁但包含 CNN 核心思想。

原始 LeNet-5 结构:
层类型输入尺寸核心参数输出尺寸作用
输入层32×32×1-32×32×1原始图像(MNIST 调整后)
卷积层 C132×32×16 个 5×5 卷积核,步长 1,无填充28×28×6提取边缘、角点等基础特征
池化层 S228×28×62×2 最大池化,步长 214×14×6压缩特征,保留关键信息
卷积层 C214×14×616 个 5×5 卷积核,步长 1,无填充10×10×16提取组合特征(如线条交叉)
池化层 S310×10×162×2 最大池化,步长 25×5×16进一步压缩特征
全连接层 F15×5×16=400120 个神经元120整合全局特征
全连接层 F212084 个神经元84细化特征映射
输出层8410 个神经元(对应 10 个数字)10输出类别得分

 

三、训练核心:损失、优化与流程

1. 损失函数(Loss Function)

损失函数是 "误差度量仪",量化模型预测与真实标签的差异,指导参数更新。

  • 代码用CrossEntropyLoss(交叉熵损失),专为分类任务设计:
    • 原理:结合了Softmax(将输出得分转为概率分布)和NLLLoss(负对数似然损失);
    • 公式:对每个样本,损失 = -log (预测类别概率),预测越准,损失越小。
2. 优化器(Optimizer)

优化器是 "参数调整器",根据损失函数的梯度(导数)更新网络参数,最小化损失。

  • 代码用Adam优化器,目前最常用的优化器之一:
    • 优势:结合了Momentum(模拟物理惯性,加速收敛)和RMSprop(自适应学习率,不同参数用不同步长);
    • 对比:比传统SGD(随机梯度下降)收敛更快,比RMSprop更稳定,适合大多数场景。
3. 训练流程

训练是 "迭代优化" 的过程,核心是正向传播→反向传播→参数更新

  1. 正向传播(Forward Propagation)
    输入数据通过网络计算输出(预测结果),同时计算损失。

  2. 反向传播(Backward Propagation)
    用链式法则从损失函数反向计算每个参数的梯度(即参数对损失的影响程度)。

  3. 参数更新
    优化器根据梯度调整参数(如w = w - lr × 梯度),降低损失。

  • 关键术语:
    • epoch:遍历整个训练集一次;
    • batch_size:每次输入网络的样本数;
    • iteration:每处理 1 个 batch 称为 1 次迭代。

四、评估与数据预处理

1. 评估指标

代码用准确率(Accuracy) 评估模型:
准确率 = 正确预测的样本数 / 总样本数 × 100%
适用于平衡数据集(如 MNIST,10 类样本数量相近)。

2. 数据预处理

预处理是 "数据标准化" 步骤,提升模型训练效率和稳定性:

  • Resize((32,32)):将 MNIST 原始 28×28 图像放大到 32×32,适配 LeNet 输入尺寸;
  • ToTensor():将图像从 PIL 格式转为 PyTorch 张量(维度为 [C, H, W]),并将像素值从 [0,255] 归一化到 [0,1];
  • Normalize((0.1307,), (0.3081,)):用 MNIST 数据集的均值(0.1307)和标准差(0.3081)标准化,使数据分布更稳定,加速收敛。

简易 代码实战

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用设备:{device}')if torch.cuda.is_available():print(f"GPU:{torch.cuda.get_device_name(0)}")print(f"当前GPU设备:{torch.cuda.current_device()}")print(f"设备属性:{torch.cuda.get_device_properties(0)}")transform = transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('../data',train=True,download=True,transform=transform)
test_dataset = datasets.MNIST('../data',train=False,download=True,transform=transform)
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=1000,shuffle=False)class LeNet(nn.Module):def __init__(self,num_classes = 10):super(LeNet,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,stride=1,padding=0),nn.ReLU())self.pool1 = nn.MaxPool2d(2,2)self.conv2 = nn.Sequential(nn.Conv2d(6,16,5,1,0),nn.ReLU())self.pool2 = nn.MaxPool2d(2,2)self.fc1 = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(120,84),nn.ReLU())self.fc3 = nn.Linear(84,num_classes)def forward(self,x):x = self.conv1(x)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = x.view(-1,16*5*5)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
model = LeNet().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.001)model.train()
train_loss = []
epochs = 10
for epoch in range(epochs):running_loss = 0.0for batch_idx,(data,target) in enumerate(train_loader):data,target = data.to(device),target.to(device)optimizer.zero_grad()output = model(data)loss = loss_function(output,target)loss.backward()optimizer.step()running_loss += loss.item()train_loss.append(loss.item())if batch_idx%100 == 0:print(f'Epoch:{epoch+1},Batch:{batch_idx},Loss:{loss.item():.4f}')print(f"Epoch:{epoch + 1},Average_Loss:{running_loss / len(train_loader):.4f}")model.eval()
correct = 0
total = 0
test_loss = []
with torch.no_grad():for data,target in test_loader:data,target = data.to(device),target.to(device)output = model(data)_,predicted = torch.max(output,1)total += target.size(0)correct += (predicted==target).sum().item()loss = loss_function(output,target)test_loss.append(loss.item())print(f"Test_Accuracy:{100 * correct / total:.2f}%")

 代码优化

1. 增强模型结构(添加 BatchNorm 和 Dropout)

# 改进的LeNet模型 - 添加BatchNorm和Dropout
class LeNet(nn.Module):def __init__(self, num_classes=10):super(LeNet, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=0),nn.BatchNorm2d(32),  # 添加BatchNormnn.ReLU(),nn.MaxPool2d(2, 2))self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=0),nn.BatchNorm2d(64),  # 添加BatchNormnn.ReLU(),nn.MaxPool2d(2, 2))self.fc1 = nn.Sequential(nn.Linear(64 * 5 * 5, 512),nn.BatchNorm1d(512),  # 添加BatchNormnn.ReLU(),nn.Dropout(0.5)  # 添加Dropout)self.fc2 = nn.Sequential(nn.Linear(512, 256),nn.BatchNorm1d(256),  # 添加BatchNormnn.ReLU(),nn.Dropout(0.5)  # 添加Dropout)self.fc3 = nn.Linear(256, num_classes)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(-1, 64 * 5 * 5)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x

2. 改进训练流程(添加早停和学习率调度)

# 初始化模型、损失函数和优化器
model = LeNet().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)# 训练配置
epochs = 20
best_val_loss = float('inf')
patience = 5
counter = 0# 训练循环
for epoch in range(epochs):print(f"\nEpoch {epoch+1}/{epochs}")print("-" * 30)train_loss, train_acc = train(model, train_loader, loss_function, optimizer, device, epoch, history)val_loss, val_acc = validate(model, val_loader, loss_function, device, history)# 学习率调度scheduler.step(val_loss)# 保存最佳模型和早停机制if val_loss < best_val_loss:print(f"验证损失下降 ({best_val_loss:.4f} --> {val_loss:.4f}). 保存模型...")torch.save(model.state_dict(), 'models/best_model.pth')best_val_loss = val_losscounter = 0else:counter += 1print(f"EarlyStopping 计数器: {counter}/{patience}")if counter >= patience:print(f"达到最大耐心值 {patience}. 停止训练...")break

3. 增强数据处理(添加数据增强)

# 数据预处理增强
transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomRotation(10),  # 随机旋转transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 随机平移transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])test_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 创建验证集
val_size = 5000
train_indices = list(range(len(train_dataset) - val_size))
val_indices = list(range(len(train_dataset) - val_size, len(train_dataset)))train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(train_dataset, val_indices)# 创建数据加载器 - 优化参数
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=1000, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=2, pin_memory=True)

4. 可视化功能(训练历史可视化)

# 可视化训练历史
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Train Accuracy')
plt.plot(history['val_acc'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.tight_layout()
plt.savefig('training_history.png')
plt.show()

5. 改进模型保存策略

# 创建保存模型的目录
if not os.path.exists('models'):os.makedirs('models')# 在训练循环中保存最佳模型
if val_loss < best_val_loss:print(f"验证损失下降 ({best_val_loss:.4f} --> {val_loss:.4f}). 保存模型...")torch.save(model.state_dict(), 'models/best_model.pth')best_val_loss = val_loss# 加载最佳模型进行测试
print("\n加载最佳模型进行测试...")
model.load_state_dict(torch.load('models/best_model.pth'))
test_acc = test(model, test_loader, device)

 6.完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
import time
import os# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用设备:{device}')if torch.cuda.is_available():print(f"GPU:{torch.cuda.get_device_name(0)}")print(f"当前GPU设备:{torch.cuda.current_device()}")print(f"设备属性:{torch.cuda.get_device_properties(0)}")# 数据预处理增强
transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomRotation(10),transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])test_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载数据集
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=test_transform)# 创建验证集
val_size = 5000
train_indices = list(range(len(train_dataset) - val_size))
val_indices = list(range(len(train_dataset) - val_size, len(train_dataset)))train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(train_dataset, val_indices)# 创建数据加载器
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=1000, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, pin_memory=True)# 改进的LeNet模型 - 添加BatchNorm和Dropout
class LeNet(nn.Module):def __init__(self, num_classes=10):super(LeNet, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=0),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2, 2))self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=0),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2, 2))self.fc1 = nn.Sequential(nn.Linear(64 * 5 * 5, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5))self.fc2 = nn.Sequential(nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(),nn.Dropout(0.5))self.fc3 = nn.Linear(256, num_classes)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(-1, 64 * 5 * 5)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 初始化模型、损失函数和优化器
model = LeNet().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)# 训练函数
def train(model, train_loader, criterion, optimizer, device, epoch, history=None):model.train()train_loss = 0correct = 0total = 0start_time = time.time()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:elapsed = time.time() - start_timeprint(f'Epoch: {epoch + 1} [{batch_idx + 1}/{len(train_loader)}] 'f'Loss: {loss.item():.4f} | Acc: {100. * correct / total:.2f}% 'f'| Batch Time: {elapsed / 100:.2f}s')start_time = time.time()avg_loss = train_loss / len(train_loader)acc = 100. * correct / totalif history is not None:history['train_loss'].append(avg_loss)history['train_acc'].append(acc)print(f'Epoch: {epoch + 1} | Train Loss: {avg_loss:.4f} | Train Acc: {acc:.2f}%')return avg_loss, acc# 验证函数
def validate(model, val_loader, criterion, device, history=None):model.eval()val_loss = 0correct = 0total = 0with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)val_loss += loss.item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = val_loss / len(val_loader)acc = 100. * correct / totalif history is not None:history['val_loss'].append(avg_loss)history['val_acc'].append(acc)print(f'Validation Loss: {avg_loss:.4f} | Validation Acc: {acc:.2f}%')return avg_loss, acc# 测试函数
def test(model, test_loader, device):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()acc = 100. * correct / totalprint(f'Test Accuracy: {acc:.2f}%')return acc# 创建保存模型的目录
if not os.path.exists('models'):os.makedirs('models')# 训练历史记录
history = {'train_loss': [], 'train_acc': [],'val_loss': [], 'val_acc': []
}# 训练配置
epochs = 20
best_val_loss = float('inf')
patience = 5
counter = 0# 训练循环
print("开始训练...")
for epoch in range(epochs):print(f"\nEpoch {epoch + 1}/{epochs}")print("-" * 30)train_loss, train_acc = train(model, train_loader, loss_function, optimizer, device, epoch, history)val_loss, val_acc = validate(model, val_loader, loss_function, device, history)# 学习率调度scheduler.step(val_loss)# 保存最佳模型if val_loss < best_val_loss:print(f"验证损失下降 ({best_val_loss:.4f} --> {val_loss:.4f}). 保存模型...")torch.save(model.state_dict(), 'models/best_model.pth')best_val_loss = val_losscounter = 0else:counter += 1print(f"EarlyStopping 计数器: {counter}/{patience}")if counter >= patience:print(f"达到最大耐心值 {patience}. 停止训练...")break# 加载最佳模型进行测试
print("\n加载最佳模型进行测试...")
model.load_state_dict(torch.load('models/best_model.pth'))
test_acc = test(model, test_loader, device)# 可视化训练历史
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Train Accuracy')
plt.plot(history['val_acc'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.tight_layout()
plt.savefig('training_history.png')
plt.show()print(f"最终测试准确率: {test_acc:.2f}%")

http://www.dtcms.com/a/272962.html

相关文章:

  • (LeetCode 每日一题) 3440. 重新安排会议得到最多空余时间 II (贪心)
  • RabbitMQ消息队列——三个核心特性
  • LeetCode 1652. 拆炸弹
  • AI时代的接口调试与文档生成:Apipost 与 Apifox 的表现对比
  • Leetcode刷题营第十九题:对链表进行插入排序
  • Python 网络爬虫中 robots 协议使用的常见问题及解决方法
  • 图解 BFS 路径搜索:LeetCode1971
  • 芯片I/O脚先于电源脚上电会导致Latch-up(闩锁效应)吗?
  • Logback日志框架配置实战指南
  • 5种使用USB数据线将文件从安卓设备传输到电脑的方法
  • 【JavaScript 函数、闭包与 this 绑定机制深度解析】
  • 【C语言】指针笔试题2
  • 模块三:现代C++工程实践(4篇)第二篇《性能调优:Profile驱动优化与汇编级分析》
  • FlashAttention 快速安装指南(避免长时间编译)
  • QT网络通信底层实现详解:UDP/TCP实战指南
  • Centos 7下使用C++使用Rdkafka库实现生产者消费者
  • 【LeetCode 热题 100】19. 删除链表的倒数第 N 个结点——双指针+哨兵
  • 学习 Flutter (一)
  • html的outline: none;
  • C++STL-deque
  • 1. COLA-DDD的实战
  • 【基础架构】——软件系统复杂度的来源(低成本、安全、规模)
  • 告别卡顿与慢响应!现代 Web 应用性能优化:从前端渲染到后端算法的全面提速指南
  • IDEA运行Spring项目报错:java: 警告: 源发行版 17 需要目标发行版 17,java: 无效的目标发行版: 17
  • Cargo.toml 配置详解
  • 【科研绘图系列】R语言探索生物多样性与地理分布的可视化之旅
  • 网安-解决pikachu-rce乱码问题
  • 访问Windows服务器备份SQL SERVER数据库
  • (C++)任务管理系统(文件存储)(正式版)(迭代器)(list列表基础教程)(STL基础知识)
  • x86交叉编译ros 工程给jetson nano运行