当前位置: 首页 > news >正文

特征图可视化代码

  • 进行特征图可视化的时候,修改模型的forward函数来进行可视化十分麻烦,还需要想办法把特征图传出来,在模型层层调用的时候更加麻烦,要修改多个无关的嵌套,还容易引起bug。这里提供了一个简单的范式,只需要一个vis.py文件(可从train.py或者test.py修改而来),无需修改模型的定义文件,即可实现特征图的可视化。
  • 该做法的核心思想是两点,第一点是利用vis.py里面的全局变量来存储特征图以及网络层数等,第二点是直接在vis.py里面重写需要可视化特征图的module的forward函数,以用最小的改动将特征图传递出来。
  • 这段代码还提供了利用sns.heatmap可视化特征图的例子,整体代码如下:
# vis talor mod
import argparse
import os
import math
from functools import partialimport yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdmimport datasets
import models
import utils
from torchvision import transforms
from PIL import Image
import random
import numpy as npimport matplotlib.pyplot as plt
import seaborn as sns
from models.models_meta import mGAttn
from einops import rearrange as rearrange@torch.no_grad()
def vis_mod(mod_dict, path, name):for i in [3, 13]:#[3,5,19]:scale = mod_dict[f'layer_{i}_scale'] offset = mod_dict[f'layer_{i}_offset']name_scale_i = name+f'scale_{i}_avg.png'vis_feature(scale, path, name_scale_i)vis_feature_each_channel(scale, path, name_scale_i)name_offset_i = name+f'offset_{i}_avg.png'vis_feature(offset, path, name_offset_i)   vis_feature_each_channel(offset, path, name_offset_i)   return def vis_feature(feature, path, name):feature = feature[0, ...]# lower_percentile = 0.1# upper_percentile = 0.9# for i in range(feature.shape[0]):#     feature_i = feature[i].view(-1)#     lower_bound = torch.quantile(feature_i, lower_percentile)#     upper_bound = torch.quantile(feature_i, upper_percentile)#     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)# plt.figure()plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature, dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name))plt.close()for i in range(feature.size(0)//32):plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature[i*32:(i+1)*32, :, :], dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'avg_{i}')))plt.close()def vis_feature_each_channel(feature, path, name):feature = feature[0, ...]# lower_percentile = 0.1# upper_percentile = 0.9# for i in range(feature.shape[0]):#     feature_i = feature[i].view(-1)#     lower_bound = torch.quantile(feature_i, lower_percentile)#     upper_bound = torch.quantile(feature_i, upper_percentile)#     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)for i in range(feature.shape[0]):# plt.figure()plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(feature[i].cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'channel_{i}')))plt.close()global_feature_maps = {}
def modify_forward_for_mGAttn(module):if isinstance(module, mGAttn):# original_forward = module.forwarddef modified_forward(self, x):"""x: b * c * h * w"""# 这里省略了模型原有的一些forward过程curr_layer = global_feature_maps['curr_layer']if curr_layer in [3, 13]:#[1,5,19]:B, h, Ch, N = offset.shapeglobal_feature_maps[f'layer_{curr_layer}'] = feature.view(B, h, He, We)global_feature_maps['curr_layer'] = curr_layer+1# 这里省略了模型原有的一些forward过程return outmodule.forward =  modified_forward.__get__(module)for child_module in module.children():modify_forward_for_mGAttn(child_module)if __name__ == '__main__':# 这里省略了一些模型的定义过程# modify forward for mGAttnmodify_forward_for_mGAttn(model)# 接着按自己的方式直接调用模型即可myeval(model)

相关文章:

  • Java中的ConcurrentHashMap的使用与原理
  • Ros真(node?package?)
  • DeepSeek部署实战:常见问题与高效解决方案全解析
  • 从零开始的数据结构教程(七) 回溯算法
  • PCIE之Lane Reserval通道out of oder调换顺序
  • TDengine 集群运行监控
  • Kubernetes RBAC权限控制:从入门到实战
  • kafka学习笔记(三、消费者Consumer使用教程——配置参数大全及性能调优)
  • 【PCI】PCI入门介绍(包含部分PCIe讲解)
  • win11安装踩坑笔记 win11 u盘安装
  • 67.实现AI流式回答的后端实现(2)
  • Windows下编译zlib
  • 属性映射框架-MapStruct
  • 使用交叉编译工具提示stubs-32.h:7:11: fatal error: gnu/stubs-soft.h: 没有那个文件或目录的解决办法
  • 【LaTex公式】在Latex公式中模拟表格
  • 34、请求处理-【源码分析】-Model、Map原理
  • VulnStack|红日靶场——红队评估四
  • python中将一个列表样式的字符串转换成真正列表的办法以及json.dumps()和 json.loads()
  • SAR ADC 同步逻辑设计
  • 2. 手写数字预测 gui版
  • 服装怎么做网站推广/代写文章兼职
  • 企业网站建设案例分析/百度竞价推广教程
  • 做数学题目在哪个网站好/seo的推广技巧
  • 电子商务网站制作步骤/百度seo优化方法
  • 公司简介模板免费ppt下载/东莞seo广告宣传
  • 做网站系统开发的意义/怎么制作网页页面