当前位置: 首页 > news >正文

经典卷积神经网络LeNet实现(pytorch版)

LeNet卷积神经网络

  • 一、理论部分
    • 1.1 核心理论
    • 1.2 LeNet-5 网络结构
    • 1.3 关键细节
    • 1.4 后期改进
    • 1.6 意义与局限性
  • 二、代码实现
    • 2.1 导包
    • 2.1 数据加载和处理
    • 2.3 网络构建
    • 2.4 训练和测试函数
      • 2.4.1 训练函数
      • 2.4.2 测试函数
    • 2.5 训练和保存模型
    • 2.6 模型加载和预测

一、理论部分

LeNet是一种经典的卷积神经网络(CNN),由Yann LeCun等人于1998年提出,最初用于手写数字识别(如MNIST数据集)。它是CNN的奠基性工作之一,其核心思想是通过局部感受野、共享权重和空间下采样来提取有效特征


1.1 核心理论

  • 局部感受野(Local Receptive Fields)
    卷积层通过小尺寸的滤波器(如5×5)扫描输入图像,每个神经元仅连接输入图像的局部区域,从而捕捉局部特征(如边缘、纹理)

  • 共享权重(Weight Sharing)
    同一卷积层的滤波器在整张图像上共享参数,显著减少参数量,增强平移不变性

  • 空间下采样(Subsampling)
    池化层(如平均池化)降低特征图的分辨率,减少计算量并增强对微小平移的鲁棒性

  • 多层特征组合
    通过交替的卷积和池化层,逐步组合低层特征(边缘)为高层特征(数字形状)


1.2 LeNet-5 网络结构

LeNet-5是LeNet系列中最著名的版本,其结构如下(输入为32×32灰度图像):

层类型 参数说明 输出尺寸
输入层 灰度图像 32×32×1
C1层 卷积层:6个5×5滤波器,步长1,无填充 28×28×6
S2层 平均池化:2×2窗口,步长2 14×14×6
C3层 卷积层:16个5×5滤波器,步长1 10×10×16
S4层 平均池化:2×2窗口,步长2 5×5×16
C5层 卷积层:120个5×5滤波器 1×1×120
F6层 全连接层:84个神经元 84
输出层 全连接 + Softmax(10类) 10

1.3 关键细节

  • 激活函数
    原始LeNet使用Tanh或Sigmoid,现代实现常用ReLU

  • 池化方式
    原始版本使用平均池化,后续改进可能用最大池化

  • 参数量优化
    C3层并非全连接至S2的所有通道,而是采用部分连接(如论文中的连接表),减少计算量

  • 输出处理
    最后通过全连接层(F6)和Softmax输出分类概率(如0-9数字)


1.4 后期改进

  • ReLU替代Tanh:解决梯度消失问题,加速训练
  • 最大池化:更关注显著特征,抑制噪声
  • Batch Normalization:稳定训练过程
  • Dropout:防止过拟合(原LeNet未使用)

1.6 意义与局限性

  • 意义
    证明了CNN在视觉任务中的有效性,启发了现代深度学习模型(如AlexNet、ResNet)

  • 局限性
    参数量小、层数浅,对复杂数据(如ImageNet)表现不足,需更深的网络结构

LeNet的设计思想至今仍是CNN的基础,理解它有助于掌握现代卷积神经网络的演变逻辑

二、代码实现

  • LeNet 是一个经典的卷积神经网络(CNN),由 Yann LeCun 等人于 1998 年提出,主要用于手写数字识别(如 MNIST 数据集)
  • MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片
  • 总体来看,LeNet(LeNet-5)由两个部分组成:(1)卷积编码器:由两个卷积层组成(2)全连接层密集块:由三个全连接层组成

2.1 导包

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from torchsummary import summary

2.1 数据加载和处理

# 加载 MNIST 数据集
def load_data(batch_size=64):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),  # 将图像转换为张量
        torchvision.transforms.Normalize((0.5,), (0.5,))  # 归一化
    ])
    
    # 下载训练集和测试集
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 创建 DataLoader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

2.3 网络构建

  • LeNet 的网络结构如下:
    • 卷积层 1:输入通道 1,输出通道 6,卷积核大小 5x5
    • 池化层 1:2x2 的最大池化
    • 卷积层 2:输入通道 6,输出通道 16,卷积核大小 5x5。
    • 池化层 2:2x2 的最大池化。
    • 全连接层 1:输入 16x5x5,输出 120
    • 全连接层 2:输入 120,输出 84
    • 全连接层 3:输入 84,输出 10(对应 10 个类别)
#定义LeNet网络架构
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.net=nn.Sequential(
            #卷积层1
            nn.Conv2d

相关文章:

  • Unity3D依赖注入容器使用指南博毅创为博毅创为
  • Java接口(二)
  • dp4-ai 安装教程
  • 化繁为简解决leetcode第1289题下降路径最小和II
  • 深度解剖 TCP 三次握手 四次挥手
  • LXC 导入多Linux系统
  • mybatis-genertor(代码生成)源码及扩展笔记
  • stm32F103C8T6引脚定义
  • python 的gui开发示例
  • MySQL Online DDL:演变、原理与实践
  • RAG 文档嵌入到向量数据库FAISS
  • 前沿科技:具身智能(Embodied Intelligence)详解
  • 利用cusur+claude3.7 angent模式一句提示词生成一个前端网站
  • 阿里拟收购两氢一氧公司 陈航将出任阿里集团钉钉 CEO
  • 【CV/NLP/生成式AI】
  • 二月公开赛Web-ssrfme
  • 4月1号.
  • Redis:主从复制
  • 机器学习+EEG熵进行双相情感障碍诊断的综合评估
  • Git基本操作
  • 湖南新宁一矿厂排水管破裂,尾砂及积水泄漏至河流,当地回应
  • 美乌矿产协议预计最早于今日签署
  • 当老年人加入“行为艺术基础班”
  • 运动健康|不同能力跑者,跑步前后营养补给差别这么大?
  • 北京公园使用指南
  • 黄仁勋访华期间表示希望继续与中国合作,贸促会回应