python打卡day47@浙大疏锦行
昨天代码中注意力热图的部分顺移至今天
知识点回顾:
热力图
作业:对比不同卷积层热图可视化的结果
以下是不同卷积层特征图可视化的对比实现:
import torch
import matplotlib.pyplot as pltdef compare_conv_layers(model, input_tensor):# 注册多个钩子获取不同层特征图layer_outputs = {}def save_output(layer_name):def hook(module, input, output):layer_outputs[layer_name] = output.detach().cpu()return hook# 选择三个不同卷积层hooks = [model.layer1[0].conv1.register_forward_hook(save_output('layer1_conv')),model.layer2[0].conv1.register_forward_hook(save_output('layer2_conv')),model.layer3[0].conv1.register_forward_hook(save_output('layer3_conv'))]# 前向传播with torch.no_grad():model(input_tensor.unsqueeze(0))# 移除钩子for hook in hooks:hook.remove()# 可视化对比fig, axes = plt.subplots(3, 5, figsize=(20, 12))for row, (layer_name, features) in enumerate(layer_outputs.items()):for col in range(5):axes[row, col].imshow(features[0, col].numpy(), cmap='viridis')axes[row, col].set_title(f"{layer_name}\nch{col}", fontsize=8)axes[row, col].axis('off')plt.tight_layout()plt.show()# 使用示例
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()
input_img = torch.randn(3, 224, 224) # 替换为实际输入图像
compare_conv_layers(model, input_img)
输出结果如图所示: