# DAY 47 注意力热图可视化 (Day 47: attention heatmap visualization)
# 知识点回顾 (key points reviewed):
#   热力图 (heatmaps)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import cv2
from torch.utils.data import DataLoader
from PIL import Image# 定义通道注意力模块
class ChannelAttention(nn.Module):
    """Channel attention gate that rescales each channel of a feature map.

    The input is reduced to one value per channel by both global average
    pooling and global max pooling.  Each pooled vector goes through the same
    two-layer bottleneck MLP, the two results are summed, and a sigmoid turns
    the sum into per-channel weights that multiply the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Divisor applied to ``in_channels`` to size the
            bottleneck of the shared MLP.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        # Integer division yields 0 when in_channels < reduction_ratio, which
        # would create an empty bottleneck; keep at least one hidden unit.
        hidden_channels = max(1, in_channels // reduction_ratio)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # One MLP shared by the average-pooled and max-pooled branches.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention weights.

        Args:
            x: 4-D tensor unpacked as ``(batch, channels, _, _)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # Average-pooling branch.
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        # Max-pooling branch.
        max_out = self.fc(self.max_pool(x).view(b, c))
        # Merge both branches and squash into (0, 1) channel weights.
        scale = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * scale


# CNN model definition (with channel attention)
class CNN(nn.Module):
    """Three-block CNN with a channel-attention gate after every conv block.

    Each block is conv -> batch norm -> ReLU -> ChannelAttention -> 2x2 max
    pool.  The classifier head is ``fc1`` (128*4*4 -> 512), ReLU, dropout and
    ``fc2`` (512 -> 10), so the network expects 3-channel inputs whose spatial
    size is 4x4 after the three poolings (e.g. 32x32 images).
    """

    def __init__(self):
        super().__init__()

        # ---------------------- Conv block 1 ----------------------
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.ca1 = ChannelAttention(in_channels=32, reduction_ratio=16)
        self.pool1 = nn.MaxPool2d(2, 2)

        # ---------------------- Conv block 2 ----------------------
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.ca2 = ChannelAttention(in_channels=64, reduction_ratio=16)
        self.pool2 = nn.MaxPool2d(2)

        # ---------------------- Conv block 3 ----------------------
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.ca3 = ChannelAttention(in_channels=128, reduction_ratio=16)
        self.pool3 = nn.MaxPool2d(2)

        # ---------------------- Fully connected classifier ----------------------
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, 10)
        self.relu_fc = nn.ReLU()  # ReLU dedicated to the fully connected head

    def forward(self, x):
        """Run the network and also return selected intermediate activations.

        Args:
            x: Batch of input images.

        Returns:
            Tuple ``(logits, intermediate_outputs)`` where the second item is a
            dict with keys ``conv1``, ``ca1``, ``conv2``, ``ca2``, ``conv3`` and
            ``ca3`` holding the raw conv output and the post-attention output
            of each block.
        """
        # Intermediate activations kept for visualization.
        intermediate_outputs = {}

        # ---------- Conv block 1 ----------
        x = self.conv1(x)
        intermediate_outputs['conv1'] = x
        x = self.relu1(self.bn1(x))
        x = self.ca1(x)
        intermediate_outputs['ca1'] = x
        x = self.pool1(x)

        # ---------- Conv block 2 ----------
        x = self.conv2(x)
        intermediate_outputs['conv2'] = x
        x = self.relu2(self.bn2(x))
        x = self.ca2(x)
        intermediate_outputs['ca2'] = x
        x = self.pool2(x)

        # ---------- Conv block 3 ----------
        x = self.conv3(x)
        intermediate_outputs['conv3'] = x
        x = self.relu3(self.bn3(x))
        x = self.ca3(x)
        intermediate_outputs['ca3'] = x
        x = self.pool3(x)

        # ---------- Flatten and classify ----------
        # Flatten per sample.  Reshaping with a fixed feature width and a free
        # batch dimension would silently change the batch size for inputs of
        # an unexpected spatial size; this way fc1 raises a shape error instead.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu_fc(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x, intermediate_outputs


# Grad-CAM generation function
def generate_gradcam(model, image, target_class, layer):
    """Compute a Grad-CAM heatmap for ``target_class`` at ``layer``.

    The model is switched to eval mode, run once on ``image``, and the score
    of ``target_class`` for the first sample is back-propagated.  The output
    of ``layer`` is weighted channel-wise by the spatial mean of its gradient,
    averaged over channels, passed through ReLU and min-max normalised.

    Args:
        model: Module whose forward returns a ``(logits, extras)`` pair.
        image: Input batch; only sample 0 contributes to the back-propagated
            score.
        target_class: Index of the class whose score is explained.
        layer: Sub-module of ``model`` whose 4-D output is visualised.

    Returns:
        NumPy array holding the normalised heatmap (values in [0, 1]).  A
        heatmap with no variation is returned as all zeros.

    Raises:
        RuntimeError: If ``layer`` produced no output during the forward pass
            or no gradient reached that output.
    """
    model.eval()

    captured = {}

    def _store_gradient(grad):
        captured['gradients'] = grad.detach()

    def _forward_hook(module, inputs, output):
        captured['feature_maps'] = output.detach()
        # A tensor hook on the layer output yields exactly d(score)/d(output)
        # and avoids the deprecated module-level register_backward_hook.
        if output.requires_grad:
            output.register_hook(_store_gradient)

    handle = layer.register_forward_hook(_forward_hook)
    try:
        # Forward pass.
        output, _ = model(image)

        # Backward pass for the selected class of the first sample.
        model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][target_class] = 1
        output.backward(gradient=one_hot)
    finally:
        # Always detach the hook, even if the forward/backward pass failed.
        handle.remove()

    if 'feature_maps' not in captured or 'gradients' not in captured:
        raise RuntimeError(
            'Grad-CAM could not capture activations/gradients for the given layer'
        )

    feature_maps = captured['feature_maps']
    gradients = captured['gradients']

    # Channel importance = spatial mean of the gradient; keepdim lets it
    # broadcast over height and width.
    pooled_gradients = gradients.mean(dim=(2, 3), keepdim=True)
    weighted_feature_maps = feature_maps * pooled_gradients

    heatmap = weighted_feature_maps.mean(dim=1).squeeze()
    heatmap = F.relu(heatmap)  # keep only positive influence

    # Min-max normalise; a flat map would otherwise divide 0 by 0 into NaN.
    low, high = heatmap.min(), heatmap.max()
    if high > low:
        heatmap = (heatmap - low) / (high - low)
    else:
        heatmap = torch.zeros_like(heatmap)

    return heatmap.cpu().numpy()


# Training function
def train(model, train_loader, test_loader, criterion, optimizer, scheduler,
          device, epochs=10, save_path='cifar_cnn_with_ca_model.pth'):
    """Train ``model`` and checkpoint the weights with the best test accuracy.

    After every epoch the model is evaluated on ``test_loader``; whenever the
    accuracy improves on the best value so far, ``model.state_dict()`` is
    written to ``save_path``.

    Args:
        model: Module whose forward returns a ``(logits, extras)`` pair.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over the model parameters.
        scheduler: Scheduler stepped once per epoch with a loss value
            (e.g. ``ReduceLROnPlateau``).
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        save_path: Destination file for the best checkpoint.

    Returns:
        Best test accuracy (in percent) seen across all epochs.
    """
    best_accuracy = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        for step, (inputs, labels) in enumerate(train_loader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches = step

            # Report progress every 100 batches.
            if step % 100 == 0:
                print(f'Epoch: {epoch+1}/{epochs} | Batch: {step}/{len(train_loader)} | '
                      f'单Batch损失: {batch_loss:.4f} | 累计平均损失: {running_loss/step:.4f}')

        # The epoch mean is a steadier plateau signal than the last mini-batch
        # loss, and is still defined when the loader yields no batches.
        epoch_loss = running_loss / num_batches if num_batches else 0.0

        # Evaluate on the test set at the end of each epoch.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs, _ = model(inputs)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total if total else 0.0
        # This figure is measured on test_loader, so label it as test accuracy.
        print(f'Epoch {epoch+1}/{epochs} 完成 | '
              f'测试准确率: {accuracy:.2f}%')

        # Update the learning rate.
        scheduler.step(epoch_loss)

        # Keep the best-performing weights.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), save_path)

    return best_accuracy


# Image preprocessing
# Image preprocessing: resize to 32x32, convert to a tensor, then normalise
# every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the CIFAR-10 train/test splits (downloaded into ./data when missing).
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Select the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# Build the model.
model = CNN().to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning-rate scheduler: halve the rate after 3 epochs without improvement
# of the monitored (minimised) value.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=3, factor=0.5
)

# Train the model.  train() returns the best test accuracy it observed.
print("开始训练带通道注意力的CNN模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")


# Visualize heatmaps of different layers
def visualize_heatmaps(model, test_loader, device):
    """Plot Grad-CAM overlays of several layers for one test image.

    Loads the weights stored in ``cifar_cnn_with_ca_model.pth``, takes the
    first image of the first batch of ``test_loader``, predicts its class and
    draws the original image next to Grad-CAM overlays for ``conv1``, ``ca1``,
    ``conv2``, ``ca2``, ``conv3`` and ``ca3``.  The figure is saved as
    ``cnn_ca_heatmaps_comparison.png`` and then shown.

    Args:
        model: Network exposing the six layers listed above; its forward
            returns a ``(logits, extras)`` pair.
        test_loader: Loader yielding ``(images, labels)`` batches; its
            ``dataset.classes`` supplies the class names for the title.
        device: Device used for the checkpoint and the image.
    """
    # Load the trained weights.
    model.load_state_dict(torch.load('cifar_cnn_with_ca_model.pth', map_location=device))
    model.eval()

    # Pick one test image.
    images, labels = next(iter(test_loader))
    image = images[0].unsqueeze(0).to(device)
    true_label = labels[0].item()

    # Predict its class.
    with torch.no_grad():
        output, _ = model(image)
    pred_class = output.argmax().item()

    # Take class names from the loader that was passed in rather than from a
    # module-level global, so the function works with any compatible loader.
    class_names = test_loader.dataset.classes

    # Layers to visualise.
    layers = {
        'conv1': model.conv1,
        'ca1': model.ca1,
        'conv2': model.conv2,
        'ca2': model.ca2,
        'conv3': model.conv3,
        'ca3': model.ca3,
    }

    plt.figure(figsize=(20, 15))

    # Undo the (0.5, 0.5) normalisation and convert to an 8-bit RGB image.
    original_image = image.squeeze().permute(1, 2, 0).cpu().numpy()
    original_image = (original_image * 0.5 + 0.5) * 255
    # Clip first so rounding error cannot wrap around in the uint8 cast.
    original_image = np.clip(original_image, 0, 255).astype(np.uint8)

    plt.subplot(3, 4, 1)
    plt.imshow(original_image)
    plt.title(f'Original Image\nTrue: {class_names[true_label]}\nPred: {class_names[pred_class]}')
    plt.axis('off')

    height, width = original_image.shape[:2]

    # Generate and show the heatmap of every layer.
    for i, (layer_name, layer) in enumerate(layers.items()):
        heatmap = generate_gradcam(model, image, pred_class, layer)
        # A degenerate heatmap may contain NaN; treat it as "no activation".
        heatmap = np.nan_to_num(heatmap)

        # Resize to the image size and colourise.
        heatmap = cv2.resize(heatmap, (width, height))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # applyColorMap returns BGR, while the image and matplotlib use RGB.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Blend the heatmap over the original image.
        superimposed_img = cv2.addWeighted(original_image, 0.6, heatmap, 0.4, 0)

        plt.subplot(3, 4, i + 2)
        plt.imshow(superimposed_img)
        plt.title(f'Grad-CAM: {layer_name}')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('cnn_ca_heatmaps_comparison.png')
    plt.show()


# Visualize the heatmaps
# Run the visualization step on the model, test loader and device created
# earlier in this script.
visualize_heatmaps(model, test_loader, device)
# @浙大疏锦行