
DAY 47 Attention Heatmap Visualization

Knowledge point recap:

Heatmaps (a minimal standalone sketch follows below, before the full training script)
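A heatmap is simply a 2D array rendered with a colormap, so each value becomes a color. The short sketch below is an added illustration (not part of the original script) using random data:

import numpy as np
import matplotlib.pyplot as plt

# Purely illustrative data: an 8x8 grid of random "importance" scores
scores = np.random.rand(8, 8)

plt.imshow(scores, cmap='jet')   # map values to colors (blue = low, red = high)
plt.colorbar(label='score')
plt.title('A basic heatmap')
plt.show()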

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import cv2
from torch.utils.data import DataLoader
from PIL import Image

# Define the channel attention module
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        # Average-pooling branch
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        # Max-pooling branch
        max_out = self.fc(self.max_pool(x).view(b, c))
        # Merge the two branches
        out = avg_out + max_out
        # Produce the per-channel weights
        scale = self.sigmoid(out).view(b, c, 1, 1)
        return x * scale
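As a quick sanity check (an added sketch, not part of the original script), the module only rescales channels, so the output keeps the input's shape; the tensor sizes below are made up for illustration.

ca = ChannelAttention(in_channels=32, reduction_ratio=16)
dummy = torch.randn(2, 32, 16, 16)    # batch of 2, 32 channels, 16x16 feature maps (illustrative)
out = ca(dummy)
print(dummy.shape, '->', out.shape)   # both torch.Size([2, 32, 16, 16]): only channel scaling, no resizing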
# Define the CNN model (with channel attention)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # ---------------------- First convolutional block ----------------------
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.ca1 = ChannelAttention(in_channels=32, reduction_ratio=16)  # channel attention module
        self.pool1 = nn.MaxPool2d(2, 2)
        # ---------------------- Second convolutional block ----------------------
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.ca2 = ChannelAttention(in_channels=64, reduction_ratio=16)  # channel attention module
        self.pool2 = nn.MaxPool2d(2)
        # ---------------------- Third convolutional block ----------------------
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.ca3 = ChannelAttention(in_channels=128, reduction_ratio=16)  # channel attention module
        self.pool3 = nn.MaxPool2d(2)
        # ---------------------- Fully connected layers (classifier) ----------------------
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, 10)
        self.relu_fc = nn.ReLU()  # ReLU dedicated to the fully connected layer

    def forward(self, x):
        # Save intermediate outputs for visualization
        intermediate_outputs = {}
        # ---------- Convolutional block 1 ----------
        x = self.conv1(x)
        intermediate_outputs['conv1'] = x  # save the conv-layer output
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.ca1(x)  # apply channel attention
        intermediate_outputs['ca1'] = x  # save the output after channel attention
        x = self.pool1(x)
        # ---------- Convolutional block 2 ----------
        x = self.conv2(x)
        intermediate_outputs['conv2'] = x
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.ca2(x)  # apply channel attention
        intermediate_outputs['ca2'] = x
        x = self.pool2(x)
        # ---------- Convolutional block 3 ----------
        x = self.conv3(x)
        intermediate_outputs['conv3'] = x
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.ca3(x)  # apply channel attention
        intermediate_outputs['ca3'] = x
        x = self.pool3(x)
        # ---------- Flatten and fully connected layers ----------
        x = x.view(-1, 128 * 4 * 4)
        x = self.fc1(x)
        x = self.relu_fc(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x, intermediate_outputs
# Grad-CAM generation function
def generate_gradcam(model, image, target_class, layer):
    model.eval()
    # Capture the feature maps and gradients via hooks
    feature_maps = None
    gradients = None

    def forward_hook(module, input, output):
        nonlocal feature_maps
        feature_maps = output.detach()

    def backward_hook(module, grad_in, grad_out):
        nonlocal gradients
        gradients = grad_out[0].detach()

    # Register the hooks (register_backward_hook is deprecated, so the full variant is used)
    handle_forward = layer.register_forward_hook(forward_hook)
    handle_backward = layer.register_full_backward_hook(backward_hook)

    # Forward pass
    output, _ = model(image)

    # Backward pass to obtain the gradients for the target class
    model.zero_grad()
    one_hot = torch.zeros_like(output)
    one_hot[0][target_class] = 1
    output.backward(gradient=one_hot)

    # Remove the hooks
    handle_forward.remove()
    handle_backward.remove()

    # Channel weights: global average of the gradients over the spatial dimensions
    pooled_gradients = torch.mean(gradients, dim=[2, 3])

    # Weight each feature map by its channel weight
    weighted_feature_maps = torch.zeros_like(feature_maps)
    for i in range(feature_maps.size(1)):
        weighted_feature_maps[:, i, :, :] = feature_maps[:, i, :, :] * pooled_gradients[:, i].view(-1, 1, 1)

    # Generate the heatmap
    heatmap = torch.mean(weighted_feature_maps, dim=1).squeeze()
    heatmap = F.relu(heatmap)  # keep only positive contributions
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)  # normalize to [0, 1]
    return heatmap.cpu().numpy()
# Training function
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=10):
    best_accuracy = 0.0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Forward pass
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Print status every 100 batches
            if i % 100 == 99:
                print(f'Epoch: {epoch+1}/{epochs} | Batch: {i+1}/{len(train_loader)} | '
                      f'Batch loss: {loss.item():.4f} | Running average loss: {running_loss/(i+1):.4f}')
        # Evaluate on the test set after each epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs, _ = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{epochs} finished | Test accuracy: {accuracy:.2f}%')
        # Update the learning rate using the epoch's average training loss
        scheduler.step(running_loss / len(train_loader))
        # Save the best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'cifar_cnn_with_ca_model.pth')
    return best_accuracy
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the model
model = CNN().to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=3, factor=0.5
)

# Train the model
print("Starting to train the CNN model with channel attention...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50)
print(f"Training finished! Best test accuracy: {final_accuracy:.2f}%")
# Visualize heatmaps from different layers
def visualize_heatmaps(model, test_loader, device):
    # Load the best saved model
    model.load_state_dict(torch.load('cifar_cnn_with_ca_model.pth', map_location=device))
    model.eval()

    # Take one test image
    images, labels = next(iter(test_loader))
    image = images[0].unsqueeze(0).to(device)
    true_label = labels[0].item()

    # Predict its class
    with torch.no_grad():
        output, _ = model(image)
        pred_class = output.argmax().item()

    # Layers to visualize
    layers = {
        'conv1': model.conv1,
        'ca1': model.ca1,
        'conv2': model.conv2,
        'ca2': model.ca2,
        'conv3': model.conv3,
        'ca3': model.ca3
    }

    # Plot setup
    plt.figure(figsize=(20, 15))

    # Show the original image (undo the normalization)
    original_image = image.squeeze().permute(1, 2, 0).cpu().numpy()
    original_image = (original_image * 0.5 + 0.5) * 255
    original_image = original_image.astype(np.uint8)
    plt.subplot(3, 4, 1)
    plt.imshow(original_image)
    plt.title(f'Original Image\nTrue: {test_dataset.classes[true_label]}\nPred: {test_dataset.classes[pred_class]}')
    plt.axis('off')

    # Generate and display the Grad-CAM heatmap for each layer
    for i, (layer_name, layer) in enumerate(layers.items()):
        heatmap = generate_gradcam(model, image, pred_class, layer)
        # Resize the heatmap to match the original image
        heatmap = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # applyColorMap returns BGR; convert for matplotlib
        # Overlay the heatmap on the original image
        superimposed_img = cv2.addWeighted(original_image, 0.6, heatmap, 0.4, 0)
        # Show the result
        plt.subplot(3, 4, i + 2)
        plt.imshow(superimposed_img)
        plt.title(f'Grad-CAM: {layer_name}')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('cnn_ca_heatmaps_comparison.png')
    plt.show()

# Visualize the heatmaps
visualize_heatmaps(model, test_loader, device)
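The Grad-CAM maps above show where in the image the network focuses; the channel-attention weights themselves can also be read out directly. The sketch below is an addition (not part of the original script; the helper name grab_sigmoid is illustrative) that hooks the sigmoid inside ca3 and plots the per-channel weights for one test image.

# Added sketch: read out the ca3 channel-attention weights for a single test image
weights = {}

def grab_sigmoid(module, inp, out):
    # out has shape (batch, channels) here, i.e. (1, 128) for ca3
    weights['ca3'] = out.detach().cpu().squeeze()

hook = model.ca3.sigmoid.register_forward_hook(grab_sigmoid)
images, _ = next(iter(test_loader))
with torch.no_grad():
    model(images[0].unsqueeze(0).to(device))
hook.remove()

plt.figure(figsize=(10, 3))
plt.bar(range(weights['ca3'].numel()), weights['ca3'].numpy())
plt.xlabel('channel index')
plt.ylabel('attention weight')
plt.title('ca3 channel-attention weights (single image)')
plt.show()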

@浙大疏锦行

