当前位置: 首页 > news >正文

Python打卡:Day50

现在我们思考下,是否可以对于预训练模型增加模块来优化其效果,这里我们会遇到一个问题

预训练模型的结构和权重是固定的,如果修改其中的模型结构,是否会大幅影响其性能。其次是训练的时候如何训练才可以更好的避免破坏原有的特征提取器的参数。

所以今天的内容,我们需要回答2个问题。

1. resnet18中如何插入cbam模块?

2. 采用什么样的预训练策略,能够更好的提高效率?

可以很明显的想到,如果是resnet18+cbam模块,那么大多数地方的代码都是可以复用的,模型定义部分需要重写。先继续之前的代码

所以很容易的想到之前第一次使用resnet的预训练策略:先冻结预训练层,然后训练其他层。之前的其它是全连接层(分类头),现在其它层还包含了每一个残差块中的cbam注意力层。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 定义通道注意力
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):"""通道注意力机制初始化参数:in_channels: 输入特征图的通道数ratio: 降维比例,用于减少参数量,默认为16"""super().__init__()# 全局平均池化,将每个通道的特征图压缩为1x1,保留通道间的平均值信息self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局最大池化,将每个通道的特征图压缩为1x1,保留通道间的最显著特征self.max_pool = nn.AdaptiveMaxPool2d(1)# 共享全连接层,用于学习通道间的关系# 先降维(除以ratio),再通过ReLU激活,最后升维回原始通道数self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降维层nn.ReLU(),  # 非线性激活函数nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升维层)# Sigmoid函数将输出映射到0-1之间,作为各通道的权重self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播函数参数:x: 输入特征图,形状为 [batch_size, channels, height, width]返回:调整后的特征图,通道权重已应用"""# 获取输入特征图的维度信息,这是一种元组的解包写法b, c, h, w = x.shape# 对平均池化结果进行处理:展平后通过全连接网络avg_out = self.fc(self.avg_pool(x).view(b, c))# 对最大池化结果进行处理:展平后通过全连接网络max_out = self.fc(self.max_pool(x).view(b, c))# 将平均池化和最大池化的结果相加并通过sigmoid函数得到通道权重attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)# 将注意力权重与原始特征相乘,增强重要通道,抑制不重要通道return x * attention #这个运算是pytorch的广播机制## 空间注意力模块
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):# 通道维度池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)attention = self.conv(pool_out)  # 卷积提取空间特征return x * self.sigmoid(attention)  # 特征与空间权重相乘## CBAM模块
class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):x = self.channel_attn(x)x = self.spatial_attn(x)return ximport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 数据预处理(与原代码一致)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载数据集(与原代码一致)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

@浙大疏锦行 

http://www.dtcms.com/a/273500.html

相关文章:

  • 将Blender、Three.js与Cesium集成构建物联网3D可视化系统
  • uniapp类似抖音视频滑动
  • GNhao,获取跨境手机SIM卡跨境通信新选择!
  • JAVA面试宝典 -《Spring Boot 自动配置魔法解密》
  • 森马服饰从 Elasticsearch 到阿里云 SelectDB 的架构演进之路
  • 【LeetCode 热题 100】146. LRU 缓存——哈希表+双向链表
  • [特殊字符]远程服务器配置pytorch环境
  • 闲庭信步使用图像验证平台加速FPGA的开发:第八课——图像数据的行缓存
  • 基于ASP.NET MVC+SQLite开发的一套(Web)图书管理系统
  • ContextMenu的Item如何绑定命令
  • 手机恢复出厂设置怎么找回数据?Aiseesoft FoneLab for Android数据恢复工具分享
  • 【音视频】HLS拉流抓包分析
  • 【mac】快捷键使用指南
  • Java 深入解析:JVM对象创建与内存机制全景图
  • uni-app获取手机当前连接的WIFI名称
  • 如何将文件从OPPO手机传输到电脑
  • 视频人脸处理——人脸面部动作提取
  • 虹科分享 | 告别实体钥匙!数字钥匙正在重构你的用车体验
  • 计算机毕业设计ssm基于JavaScript的餐厅点餐系统 SSM+Vue智慧餐厅在线点餐管理平台 JavaWeb前后端分离式餐饮点餐与桌台调度系统
  • 【前端】【组件库开发】【原理】【无框架开发】现代网页弹窗开发指南:从基础到优化
  • Python day58
  • rom定制系列------红米note10 5G版camellia原生安卓14批量线刷 miui安卓11修改型号root版
  • php use 命名空间与 spl_autoload_register的关系
  • Microsoft Word 中 .doc 和 .docx 的区别
  • 重构下一代智能电池“神经中枢”:GCKontrol定义高性能BMS系统级设计标杆
  • 2025年渗透测试面试题总结-2025年HW(护网面试) 41(题目+回答)
  • 基于开源AI智能名片链动2+1模式与S2B2C商城小程序的渠道选择策略研究
  • SpringDataRedis入门
  • 慕尚花坊项目笔记
  • ADSP-21489用SigmaStudio+(SS+)来做开发的详解六、T的用法