当前位置: 首页 > news >正文

特征图可视化代码

  • 进行特征图可视化的时候,修改模型的forward函数来进行可视化十分麻烦,还需要想办法把特征图传出来,在模型层层调用的时候更加麻烦,要修改多个无关的嵌套,还容易引起bug。这里提供了一个简单的范式,只需要一个vis.py文件(可从train.py或者test.py修改而来),无需修改模型的定义文件,即可实现特征图的可视化。
  • 该做法的核心思想是两点,第一点是利用vis.py里面的全局变量来存储特征图以及网络层数等,第二点是直接在vis.py里面重写需要可视化特征图的module的forward函数,以用最小的改动将特征图传递出来。
  • 这段代码还提供了利用sns.heatmap可视化特征图的例子,整体代码如下:
# vis talor mod
import argparse
import os
import math
from functools import partialimport yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdmimport datasets
import models
import utils
from torchvision import transforms
from PIL import Image
import random
import numpy as npimport matplotlib.pyplot as plt
import seaborn as sns
from models.models_meta import mGAttn
from einops import rearrange as rearrange@torch.no_grad()
def vis_mod(mod_dict, path, name):for i in [3, 13]:#[3,5,19]:scale = mod_dict[f'layer_{i}_scale'] offset = mod_dict[f'layer_{i}_offset']name_scale_i = name+f'scale_{i}_avg.png'vis_feature(scale, path, name_scale_i)vis_feature_each_channel(scale, path, name_scale_i)name_offset_i = name+f'offset_{i}_avg.png'vis_feature(offset, path, name_offset_i)   vis_feature_each_channel(offset, path, name_offset_i)   return def vis_feature(feature, path, name):feature = feature[0, ...]# lower_percentile = 0.1# upper_percentile = 0.9# for i in range(feature.shape[0]):#     feature_i = feature[i].view(-1)#     lower_bound = torch.quantile(feature_i, lower_percentile)#     upper_bound = torch.quantile(feature_i, upper_percentile)#     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)# plt.figure()plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature, dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name))plt.close()for i in range(feature.size(0)//32):plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature[i*32:(i+1)*32, :, :], dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'avg_{i}')))plt.close()def vis_feature_each_channel(feature, path, name):feature = feature[0, ...]# lower_percentile = 0.1# upper_percentile = 0.9# for i in range(feature.shape[0]):#     feature_i = feature[i].view(-1)#     lower_bound = torch.quantile(feature_i, lower_percentile)#     upper_bound = torch.quantile(feature_i, upper_percentile)#     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)for i in range(feature.shape[0]):# plt.figure()plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(feature[i].cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'channel_{i}')))plt.close()global_feature_maps = {}
def modify_forward_for_mGAttn(module):if isinstance(module, mGAttn):# original_forward = module.forwarddef modified_forward(self, x):"""x: b * c * h * w"""# 这里省略了模型原有的一些forward过程curr_layer = global_feature_maps['curr_layer']if curr_layer in [3, 13]:#[1,5,19]:B, h, Ch, N = offset.shapeglobal_feature_maps[f'layer_{curr_layer}'] = feature.view(B, h, He, We)global_feature_maps['curr_layer'] = curr_layer+1# 这里省略了模型原有的一些forward过程return outmodule.forward =  modified_forward.__get__(module)for child_module in module.children():modify_forward_for_mGAttn(child_module)if __name__ == '__main__':# 这里省略了一些模型的定义过程# modify forward for mGAttnmodify_forward_for_mGAttn(model)# 接着按自己的方式直接调用模型即可myeval(model)
http://www.dtcms.com/a/224914.html

相关文章:

  • Java中的ConcurrentHashMap的使用与原理
  • Ros真(node?package?)
  • DeepSeek部署实战:常见问题与高效解决方案全解析
  • 从零开始的数据结构教程(七) 回溯算法
  • PCIE之Lane Reserval通道out of oder调换顺序
  • TDengine 集群运行监控
  • Kubernetes RBAC权限控制:从入门到实战
  • kafka学习笔记(三、消费者Consumer使用教程——配置参数大全及性能调优)
  • 【PCI】PCI入门介绍(包含部分PCIe讲解)
  • win11安装踩坑笔记 win11 u盘安装
  • 67.实现AI流式回答的后端实现(2)
  • Windows下编译zlib
  • 属性映射框架-MapStruct
  • 使用交叉编译工具提示stubs-32.h:7:11: fatal error: gnu/stubs-soft.h: 没有那个文件或目录的解决办法
  • 【LaTex公式】在Latex公式中模拟表格
  • 34、请求处理-【源码分析】-Model、Map原理
  • VulnStack|红日靶场——红队评估四
  • python中将一个列表样式的字符串转换成真正列表的办法以及json.dumps()和 json.loads()
  • SAR ADC 同步逻辑设计
  • 2. 手写数字预测 gui版
  • 声纹技术体系:从理论基础到工程实践的完整技术架构
  • VAE在扩散模型中的技术实现与应用
  • 算法训练第三天
  • 跑步前热身动作
  • Python应用for循环遍历寻b
  • RAGFlow从理论到实战的检索增强生成指南
  • 在win10/11下Node.js安装配置教程
  • Java 认识异常
  • 桥 接 模 式
  • 介绍一种LDPC码译码器