当前位置: 首页 > news >正文

「日拱一码」083 深度学习——残差网络

目录

残差网络(ResNet)介绍

核心思想与要解决的问题

解决方案:残差学习(Residual Learning)

为什么有效?

代码示例


残差网络(ResNet)介绍

核心思想与要解决的问题

在深度学习领域,一个直觉是:网络越深(层数越多),其能够学习到的特征就越复杂,性能也应该越好。然而,实验发现,当网络深度增加到一定程度时,模型的准确率会达到饱和,甚至开始迅速下降。这种现象被称为退化问题(Degradation Problem)

退化问题并非由过拟合引起(因为训练误差也会随之增加),而是因为深度网络难以训练,尤其是存在梯度消失/爆炸等问题,使得深层网络难以被优化。

解决方案:残差学习(Residual Learning)

ResNet 的作者何恺明等人提出了一个革命性的思路:与其让堆叠的层直接学习一个潜在的映射 H(x),不如让它们学习一个残差映射(Residual Mapping)

  • 原始映射H(x) = desired underlying mapping
  • 残差映射F(x) = H(x) - x
  • 目标映射变为H(x) = F(x) + x

这里的 x 是输入,也称为恒等映射(Identity Mapping)快捷连接(Shortcut Connection)

通过这种结构,学习的目标从 H(x) 变成了 F(x),即学习输出与输入之间的残差(差值)

为什么有效?

  1. 缓解梯度消失:梯度可以直接通过快捷连接反向传播到更浅的层,使得深层网络的训练变得更加容易。
  2. 恒等映射是高效的:如果某一层的输出已经是最优了(即 H(x) = x 是最优解),那么将残差 F(x) 学习为 0 比学习一个恒等映射要容易得多(F(x) = 0 比 F(x) = x 更简单)。
  3. 集成行为:ResNet 可以被看作是一系列路径集合的浅层网络,改善了网络的表达能力和训练动态。

ResNet 的提出使得构建成百上千层的深度神经网络成为可能,并在图像分类、目标检测等计算机视觉任务上取得了突破性的成果。

代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 设置设备(GPU如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 1. 定义残差块 (Residual Block)
class BasicBlock(nn.Module):expansion = 1def __init__(self, in_planes, planes, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)# 快捷连接 (Shortcut Connection)self.shortcut = nn.Sequential()if stride != 1 or in_planes != self.expansion * planes:self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(self.expansion * planes))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)  # 残差连接out = F.relu(out)return out# 2. 定义小型ResNet模型
class SmallResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=10):super(SmallResNet, self).__init__()self.in_planes = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)# 残差层self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)# 分类层self.linear = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, planes, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_planes, planes, stride))self.in_planes = planes * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = F.avg_pool2d(out, 4)  # CIFAR-10最终特征图尺寸为4x4out = out.view(out.size(0), -1)out = self.linear(out)return out# 3. 数据预处理和加载
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)# 4. 初始化模型、损失函数和优化器
model = SmallResNet(BasicBlock, [2, 2, 2, 2]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)# 5. 训练和测试函数
def train(epoch):model.train()train_loss = 0correct = 0total = 0for batch_idx, (inputs, targets) in enumerate(train_loader):inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()acc = 100. * correct / totalprint(f'Epoch: {epoch} | Train Loss: {train_loss / (batch_idx + 1):.3f} | Acc: {acc:.3f}%')def test(epoch):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(test_loader):inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)loss = criterion(outputs, targets)test_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()acc = 100. * correct / totalprint(f'Test Loss: {test_loss / (batch_idx + 1):.3f} | Acc: {acc:.3f}%')return accif __name__ == '__main__':# 6. 训练循环train_accuracies = []test_accuracies = []print("开始训练...")for epoch in range(1, 101):  # 训练100个epochtrain_acc = train(epoch)test_acc = test(epoch)scheduler.step()train_accuracies.append(train_acc)test_accuracies.append(test_acc)print("训练完成!")# 7. 绘制准确率曲线plt.figure(figsize=(10, 5))plt.plot(range(1, 101), train_accuracies, label='Train Accuracy')plt.plot(range(1, 101), test_accuracies, label='Test Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('ResNet on CIFAR-10')plt.legend()plt.grid(True)plt.savefig('resnet_cifar10.png')plt.show()


文章转载自:

http://z8piaxYN.skrww.cn
http://cOtQLtRa.skrww.cn
http://q3NiXlNR.skrww.cn
http://voYX88dF.skrww.cn
http://cuteuwT0.skrww.cn
http://JYhBmjyr.skrww.cn
http://c0OGb8c9.skrww.cn
http://ptpxggY2.skrww.cn
http://WUsCVc2K.skrww.cn
http://3vcztRY5.skrww.cn
http://PtJYkKuM.skrww.cn
http://D7tJsd3F.skrww.cn
http://dgyeJjgt.skrww.cn
http://rfW6E5bw.skrww.cn
http://PufUfKfd.skrww.cn
http://sdSxFTCh.skrww.cn
http://Wj8p4umg.skrww.cn
http://X7EFcGS6.skrww.cn
http://03F5rAl8.skrww.cn
http://verJpDxP.skrww.cn
http://gzN23VcX.skrww.cn
http://NLrCrNBw.skrww.cn
http://vLq7u2KZ.skrww.cn
http://2OsRb82S.skrww.cn
http://lh4IZM8U.skrww.cn
http://5ewz6MPx.skrww.cn
http://AG6RMwvq.skrww.cn
http://dyHLBetZ.skrww.cn
http://ixWTfpwJ.skrww.cn
http://fxzR9iLg.skrww.cn
http://www.dtcms.com/a/374239.html

相关文章:

  • 注意力模块改进方法的原理及实现(MHA、MQA、GQA、MLA)
  • 蚂蚁 S21 Pro 220T矿机参数详解:SHA-256算法高效算力分析
  • 大模型测试包含哪些方面
  • 基于R语言的物种气候生态位动态量化与分布特征模拟
  • NGUI--Anchor组件和 事件系统
  • 基于Django的“酒店推荐系统”设计与开发(源码+数据库+文档+PPT)
  • OpenLayers数据源集成 -- 章节一:图像图层详解
  • 深度学习架构的硬件共生论:为什么GPU决定了AI的进化方向(Transformer、SSM、Mamba、MoE、CNN是什么、对比表格)
  • AndroidWorld+mobileRL
  • langchain4j笔记篇(阳哥)
  • 精简删除WIN11.24H2企业版映像内的OneDrive安装程序方法,卸载OneDrive组件
  • spring指南学习随记(一)
  • 安装配置简易VM虚拟机(CentOS 7)
  • 虚拟机中centos简单配置
  • commons-logging
  • 【小宁学习日记6 PCB】电路原理图
  • Rust位置表达式和值表达式
  • 对比:ClickHouse/MySQL/Apache Doris
  • 2025年学英语学习机选购指南
  • 浪涌测试主要用于评估电子设备或元器件在遭受短时高强度电压 / 电流冲击(浪涌)时的耐受能力
  • ANDROID,Jetpack Compose, 贪吃蛇小游戏Demo
  • html中列表和表格的使用
  • MyBatis-Plus 深度解析:IService 接口全指南
  • iPaaS 如何帮助 CIO 减少 50% 的集成成本?
  • [运动控制]PID算法再深入--多环组合控制
  • llm的一点学习笔记
  • JVM详解(一)--JVM和Java体系结构
  • Java字符串处理:String、StringBuilder与StringBuffer
  • SQL 注入与防御-第十章:确认并从 SQL 注入攻击中恢复
  • MCP(模型上下文协议)入门教程1