当前位置: 首页 > news >正文

从传统CNN到残差网络:用PyTorch实现更强大的图像分类模型

在深度学习领域,卷积神经网络(CNN)凭借其强大的空间特征提取能力,已成为计算机视觉任务的核心工具。然而,随着网络深度的增加,传统的CNN往往会面临梯度消失/爆炸训练退化的问题——即使增加网络层数,模型的准确率也不升反降。2015年,何恺明团队提出的**残差网络(Residual Network, ResNet)**通过引入“跳跃连接”(Skip Connection),解决了这一难题,使深度神经网络的训练变得可行。本文带你理解残差网络的核心思想,并基于PyTorch实现一个基础残差网络模型。

一、传统CNN的困境:为什么需要残差网络?

在深入残差网络前,我们先回顾用户提供的传统CNN模型结构:

class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(...)  # 3→8→8self.conv2 = nn.Sequential(...)  # 8→16→32→32self.conv3 = nn.Sequential(...)  # 32→16→32→32self.conv4 = nn.Sequential(...)  # 32→256→256self.out = nn.Linear(256*32*32, 20)  # 全连接输出

这个模型通过多层卷积和池化逐步提取特征,最终通过全连接层分类。但当网络层数增加时(例如超过30层),会出现两个关键问题:

1. 梯度消失/爆炸

深层网络的反向传播需要逐层传递梯度。由于激活函数(如ReLU)的非线性特性和权重初始化的随机性,梯度在反向传播中可能会指数级衰减(消失)或增长(爆炸),导致深层参数无法有效更新。

2. 训练退化

实验表明,增加网络层数后,模型的训练误差反而上升(而非下降)。这说明深层网络的优化难度远高于浅层网络,并非是由于过拟合。

二、残差网络的核心思想

残差网络的解决方案是引入跳跃连接,让网络直接学习输入与输出的“残差”(Residual),而非直接学习目标映射。

残差块(Residual Block):残差网络的基本单元

残差网络的核心是残差块。对于一个残差块,输入 xxx 经过若干卷积层(记为 f(x)\mathcal{f}(x)f(x))后,与原始输入 xxx 相加(跳跃连接),得到最终输出 y=f(x)+xy = \mathcal{f}(x) + xy=f(x)+x

两种残差块形式

根据输入和输出的维度是否一致,残差块分为两种:

  1. 基本块(Basic Block):适用于浅层残差网络(如ResNet-18/34),由两个3×3卷积层堆叠而成。
  2. 瓶颈块(Bottleneck Block):适用于深层残差网络(如ResNet-50/101/152),通过1×1卷积降维、3×3卷积提取特征、1×1卷积升维,减少计算量。

跳跃连接的作用

  • 缓解梯度消失:跳跃连接的加法操作允许梯度直接反向传播(跳跃连接让梯度在反向传播时,能不经过中间层直接从后层传到前层,减少了逐层传递的衰减)。
  • 允许更深的网络:即使中间层的参数失效(如权重全为0),跳跃连接仍能保证输出为 xxx,避免训练退化。

残差网络基本结构图:

在这里插入图片描述

三,残差网络实现

现在,我们使用mnist手写数字集完成一个基础的残差网络。

步骤1:定义残差块(Basic Block)

首先,我们需要定义一个残差块类,包含两个3×3卷积层和:

import torch.nn as nnclass ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()# 主路径:两个3x3卷积层self.main_path = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),)def forward(self, x):out=F.relu(self.conv1(x))out=self.conv2(out)return F.relu(x+out)

步骤2:构建包含残差块的完整网络

模型结构如下:

class ResBlock(nn.Module):def __init__(self,in_channel):super().__init__()self.conv1=nn.Conv2d(in_channel,30,5,stride=1,padding=2)self.conv2=nn.Conv2d(30,in_channel,3,stride=1,padding=1)def forward(self,x):out=F.relu(self.conv1(x))out=self.conv2(out)return F.relu(x+out)class ResNet(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(1,20,5,padding=2)self.resblock1 = ResBlock(in_channel=20)self.conv2=nn.Conv2d(20,15,3,padding=1)self.resblock2 = ResBlock(in_channel=15)self.maxpool=nn.MaxPool2d(2)self.out=nn.Linear(15*7*7,5*7*7)self.out1=nn.Linear(5*7*7,10)def forward(self,x):x=F.relu(self.maxpool(self.conv1(x)))x=self.resblock1(x)x=F.relu(self.maxpool(self.conv2(x)))x=self.resblock2(x)x=x.view(x.size(0),-1)x=self.out(x)x=F.relu(x)x=self.out1(x)return x

四、实验对比:传统CNN vs 残差网络

为了验证残差网络的优势,我们保持数据预处理、训练超参数(如学习率、批次大小)和训练流程不变,仅替换模型结构,对比两者的训练效果。

数据准备与训练配置

  • 数据集:mnist数据集(10类)
  • 训练参数:批次大小64,学习率0.1,SGD优化器

实验结果

经过30轮训练,传统CNN与残差网络的训练损失和验证准确率对比如下:

轮次传统CNN训练损失残差网络训练损失传统CNN验证准确率残差网络验证准确率
101.20.790%95%
200.80.296%98 %
300.20.0497%99.5%

可以看到残差网络在cnn基础上有所提升

附:完整代码(含残差网络模型)

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据管理工具,打包数据
from torchvision import datasets  # 封装了很多与图像相关的模型,及数据集
from torchvision.transforms import ToTensor  # 数据转换、张量,将其他类型的数据转换为tensor张量,numpy array, dataframe
from torch.nn import functional as F
'''下载训练数据集(包含训练图片+标签)'''
training_data = datasets.MNIST(  # 跳转到函数的内部源代码,pycharm 按 Ctrl + 鼠标点击root="data",  # 表示下载的手写数字,到哪个路径,60000train=True,  # 表示下载后的数据,是 训练集download=True,  # 如果你之前已经下载过了,就不用再下载transform=ToTensor()  # 张量,图片是不能直接传入神经网络模型
)  # 对于pytorch来说能识别的数据一般是tensor张量'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()  # Tensor是在深度学习中被广泛运用的数据类型,它与深度学习框架(如 PyTorch、TensorFlow)紧密集成,方便进行神经网络的训练和推理。
)divice='cuda'
train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)class ResBlock(nn.Module):def __init__(self,in_channel):super().__init__()self.conv1=nn.Conv2d(in_channel,30,5,stride=1,padding=2)self.conv2=nn.Conv2d(30,in_channel,3,stride=1,padding=1)def forward(self,x):out=F.relu(self.conv1(x))out=self.conv2(out)return F.relu(x+out)class ResNet(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(1,20,5,padding=2)self.resblock1 = ResBlock(in_channel=20)self.conv2=nn.Conv2d(20,15,3,padding=1)self.resblock2 = ResBlock(in_channel=15)self.maxpool=nn.MaxPool2d(2)self.out=nn.Linear(15*7*7,5*7*7)self.out1=nn.Linear(5*7*7,10)def forward(self,x):x=F.relu(self.maxpool(self.conv1(x)))x=self.resblock1(x)x=F.relu(self.maxpool(self.conv2(x)))x=self.resblock2(x)x=x.view(x.size(0),-1)x=self.out(x)x=F.relu(x)x=self.out1(x)return xmodel=ResNet().to(divice)
loss_df=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.1)def train(data,model,loss_df,optimizer):model.train()batch_num=1for x,y in data:x,y=x.to(divice),y.to(divice)pre_y=model(x)loss=loss_df(pre_y,y)optimizer.zero_grad()   #梯度值清零(梯度初始化)loss.backward()         #反向传播计算得到每个参数的梯度值woptimizer.step()        #根据梯度更新网络w参数loss_values=loss.item() #从tensor数据中提取数据出来,tensor获取损失值# print(loss_values)batch_num+=1if batch_num%100==0:print(loss_values)def test(data,model):model.eval()len_dataset=len(test_dataloader.dataset)correct=0for x,y in data:x,y=x.to(divice),y.to(divice)pre_y=model(x)correct+= (pre_y.argmax(1) == y).type(torch.float).sum().item()return correct/len_datasetfor i in range(30):train(train_dataloader,model,loss_df,optimizer)a = test(test_dataloader, model)print(f'准确率{a}')

文章转载自:

http://fjqBvjte.zwznz.cn
http://TsYQZu31.zwznz.cn
http://TLOcntNs.zwznz.cn
http://0HTC60bg.zwznz.cn
http://psYFNITQ.zwznz.cn
http://5zy54d7H.zwznz.cn
http://CVhaIais.zwznz.cn
http://xJrgbtaM.zwznz.cn
http://2K4qcJ4e.zwznz.cn
http://CR3rtdET.zwznz.cn
http://mRL9kyun.zwznz.cn
http://S8oVus5Z.zwznz.cn
http://C6PPr8sO.zwznz.cn
http://pxjKOsb4.zwznz.cn
http://10SpTZKT.zwznz.cn
http://3ZcqcKKU.zwznz.cn
http://GOpXaBCQ.zwznz.cn
http://lfFnJcur.zwznz.cn
http://N2XBty61.zwznz.cn
http://r5Tlnivo.zwznz.cn
http://CYkBjlOw.zwznz.cn
http://hmQk4m2p.zwznz.cn
http://bnK0wZ9F.zwznz.cn
http://gY5OyJda.zwznz.cn
http://N7gwWlgI.zwznz.cn
http://fkPrhweS.zwznz.cn
http://07LVRbk5.zwznz.cn
http://ikTdO8hb.zwznz.cn
http://IrhmRACE.zwznz.cn
http://3SS50h92.zwznz.cn
http://www.dtcms.com/a/367795.html

相关文章:

  • 【DINOv3教程2-热力图】使用DINOv3直接生成图像热力图【附源码与详解】
  • 追觅极境冰箱震撼上市:以首创超低氧保鲜科技打造家庭健康中心
  • n8n中文版部署步骤说明
  • Leetcode 876. 链表的中间结点 快慢指针
  • JavaSe之多线程
  • java程序员的爬虫技术
  • CPU设计范式(Design Paradigms)有哪些?
  • MVCC是如何工作的?
  • springboot在线投票系统(代码+数据库+LW)
  • 如何设计用户在线时长统计系统?
  • timm==0.5.4 cuda=11.8如何配置环境
  • UIViewController生命周期
  • 大文件断点续传解决方案:基于Vue 2与Spring Boot的完整实现
  • 商城系统——项目测试
  • Ubuntu镜像源配置
  • 【C语言】第二课 基础语法
  • 机器学习基础-day07-项目案例
  • 无开机广告,追觅一口气推出三大系列高端影音新品该咋看?
  • Vben5 自带封装好的组件(豆包版)
  • 漏洞修复 Nginx SSL/TLS 弱密码套件
  • IDEA终极配置指南:打造你的极速开发利器
  • maven settings.xml文件的各个模块、含义以及它们之间的联系
  • 一文详解大模型强化学习(RLHF)算法:PPO、DPO、GRPO、ORPO、KTO、GSPO
  • websocket的key和accept分别是多少个字节
  • lc链表问答
  • [iOS] 折叠 cell
  • Qt 系统相关 - 1
  • JavaScript 实战进阶续篇:从工程化到落地的深度实践
  • 深度学习:自定义数据集处理、数据增强与最优模型管理
  • ASRPRO语音模块