- 进行特征图可视化的时候,修改模型的forward函数来进行可视化十分麻烦,还需要想办法把特征图传出来,在模型层层调用的时候更加麻烦,要修改多个无关的嵌套,还容易引起bug。这里提供了一个简单的范式,只需要一个vis.py文件(可从train.py或者test.py修改而来),无需修改模型的定义文件,即可实现特征图的可视化。
- 该做法的核心思想是两点,第一点是利用vis.py里面的全局变量来存储特征图以及网络层数等,第二点是直接在vis.py里面重写需要可视化特征图的module的forward函数,以用最小的改动将特征图传递出来。
- 这段代码还提供了利用sns.heatmap可视化特征图的例子,整体代码如下:
import argparse
import os
import math
from functools import partialimport yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdmimport datasets
import models
import utils
from torchvision import transforms
from PIL import Image
import random
import numpy as npimport matplotlib.pyplot as plt
import seaborn as sns
from models.models_meta import mGAttn
from einops import rearrange as rearrange@torch.no_grad()
def vis_mod(mod_dict, path, name):for i in [3, 13]:scale = mod_dict[f'layer_{i}_scale'] offset = mod_dict[f'layer_{i}_offset']name_scale_i = name+f'scale_{i}_avg.png'vis_feature(scale, path, name_scale_i)vis_feature_each_channel(scale, path, name_scale_i)name_offset_i = name+f'offset_{i}_avg.png'vis_feature(offset, path, name_offset_i) vis_feature_each_channel(offset, path, name_offset_i) return def vis_feature(feature, path, name):feature = feature[0, ...]plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature, dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name))plt.close()for i in range(feature.size(0)//32):plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature[i*32:(i+1)*32, :, :], dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'avg_{i}')))plt.close()def vis_feature_each_channel(feature, path, name):feature = feature[0, ...]for i in range(feature.shape[0]):plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(feature[i].cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'channel_{i}')))plt.close()global_feature_maps = {}
def modify_forward_for_mGAttn(module):if isinstance(module, mGAttn):def modified_forward(self, x):"""x: b * c * h * w"""curr_layer = global_feature_maps['curr_layer']if curr_layer in [3, 13]:B, h, Ch, N = offset.shapeglobal_feature_maps[f'layer_{curr_layer}'] = feature.view(B, h, He, We)global_feature_maps['curr_layer'] = curr_layer+1return outmodule.forward = modified_forward.__get__(module)for child_module in module.children():modify_forward_for_mGAttn(child_module)if __name__ == '__main__':modify_forward_for_mGAttn(model)myeval(model)