当前位置: 首页 > news >正文

关于可视化卷积核和特征图的深度理解

可视化卷积核和特征图,我们可以使用 PyTorch 结合 Matplotlib 来实现。


一、源代码

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torchvision import models, transforms
from PIL import Image
import time# 修复字体问题:使用系统可用的中文字体或默认字体
try:# 尝试设置中文字体plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
except:# 若中文字体不可用,使用默认字体避免警告plt.rcParams["font.family"] = ["sans-serif"]plt.rcParams['axes.unicode_minus'] = False# 1. 卷积核可视化函数
def visualize_kernels(kernel, title="卷积核可视化"):"""可视化卷积核"""out_channels, in_channels, k_h, k_w = kernel.shape# 创建画布fig, axes = plt.subplots(in_channels, out_channels, figsize=(out_channels * 2, in_channels * 2))fig.suptitle(title, fontsize=16)# 确保axes是可迭代的二维数组if in_channels == 1 and out_channels == 1:axes = np.array([[axes]])elif in_channels == 1:axes = np.array([axes])elif out_channels == 1:axes = np.array([[ax] for ax in axes])# 遍历并显示每个卷积核for i in range(in_channels):for j in range(out_channels):ax = axes[i, j]# 标准化卷积核值到0-1范围kernel_data = kernel[j, i].detach().numpy()kernel_data = (kernel_data - kernel_data.min()) / (kernel_data.max() - kernel_data.min() + 1e-8)ax.imshow(kernel_data, cmap='gray')ax.set_title(f'输入通道{i}, 输出通道{j}')ax.axis('off')plt.tight_layout()plt.subplots_adjust(top=0.9)print(f"显示{title},关闭窗口后将显示特征图...")plt.show()# 2. 特征图可视化函数
def visualize_feature_maps(feature_maps, title="特征图可视化", num_cols=8):"""可视化特征图"""batch_size, channels, h, w = feature_maps.shapefeature_maps = feature_maps[0]  # 只取第一个样本# 计算布局num_rows = (channels + num_cols - 1) // num_colsfig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))fig.suptitle(title, fontsize=16)# 确保axes是二维数组if num_rows == 1 and num_cols == 1:axes = np.array([[axes]])elif num_rows == 1:axes = np.array([axes])# 显示每个特征图for i in range(channels):row = i // num_colscol = i % num_colsax = axes[row, col]# 标准化特征图值feature_map = feature_maps[i].detach().numpy()feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8)ax.imshow(feature_map, cmap='viridis')ax.set_title(f'特征图 {i + 1}')ax.axis('off')# 隐藏未使用的子图for i in range(channels, num_rows * num_cols):row = i // num_colscol = i % num_colsaxes[row, col].axis('off')plt.tight_layout()plt.subplots_adjust(top=0.9)print(f"显示{title}")plt.show()# 3. 自定义卷积层演示
def demo_custom_conv():print("=== 开始自定义卷积层可视化演示 ===")# 创建卷积层 (1输入通道, 4输出通道, 3x3卷积核)conv_layer = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)print(f"卷积层参数: {conv_layer}")# 可视化卷积核visualize_kernels(conv_layer.weight, title="自定义卷积层的卷积核")# 创建输入数据 (1个样本, 1通道, 28x28大小)input_data = torch.randn(1, 1, 28, 28)print(f"输入数据形状: {input_data.shape}")# 计算特征图with torch.no_grad():  # 关闭梯度计算feature_maps = conv_layer(input_data)print(f"特征图形状: {feature_maps.shape}")  # 应为 [1, 4, 28, 28]# 显示特征图time.sleep(0.5)  # 短暂延迟确保窗口顺序显示visualize_feature_maps(feature_maps, title="自定义卷积层的特征图", num_cols=2)print("=== 自定义卷积层可视化演示结束 ===")# 4. 预训练模型演示
def demo_pretrained_model():print("\n=== 开始预训练模型可视化演示 ===")# 加载VGG16模型print("加载VGG16预训练模型...")model = models.vgg16(pretrained=True)model.eval()  # 评估模式# 获取第一个卷积层first_conv = model.features[0]print(f"第一个卷积层: {first_conv}")# 可视化卷积核visualize_kernels(first_conv.weight, title="VGG16第一个卷积层的卷积核")# 准备输入图像transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])try:# 尝试加载本地图像img_path = r"E:c39e6ff6-d65f-4a9b-b9ef-bf87e52b2665.png"img = Image.open(img_path).convert('RGB')print(f"成功加载图像: {img_path}")except Exception as e:# 加载失败时使用随机数据print(f"图像加载失败: {e},使用随机数据替代")img = torch.randn(3, 224, 224)  # 模拟RGB图像img = transform(img).unsqueeze(0)  # 添加batch维度print(f"处理后图像形状: {img.shape}")# 用钩子获取特征图feature_maps = Nonedef hook_fn(module, input, output):nonlocal feature_mapsfeature_maps = outputhook = first_conv.register_forward_hook(hook_fn)# 前向传播计算特征图with torch.no_grad():model(img)hook.remove()  # 移除钩子print(f"VGG16特征图形状: {feature_maps.shape}")# 显示特征图time.sleep(0.5)visualize_feature_maps(feature_maps, title="VGG16第一个卷积层的特征图", num_cols=10)print("=== 预训练模型可视化演示结束 ===")if __name__ == "__main__":# 运行自定义卷积层演示(会显示两个窗口:先卷积核,后特征图)demo_custom_conv()# 如需运行预训练模型演示,取消下面一行注释# demo_pretrained_model()

注:如果要使用自己的图片,可以取消demo_pretrained_model函数中相关代码的注释,并替换为自己的图片路径。


二、代码功能解析

上面的代码提供了两种主要的可视化功能,这段代码的核心目的是帮助你直观地看到卷积神经网络 (CNN) 的 "眼睛" 和 "所见之物",让抽象的卷积操作变得可视化、可理解。

1.卷积核可视化(模型的 "眼睛")

  • 将卷积核以图像形式展示出来
  • 支持多输入通道和多输出通道的卷积核
  • 自动标准化卷积核数值以获得更好的显示效果

卷积核就像模型的 "眼睛",每个卷积核负责识别输入中的特定模式(比如边缘、纹理、颜色等)。

代码中的visualize_kernels函数会把这些卷积核以图像形式画出来。比如:

  • 边缘检测的卷积核会呈现明暗交替的条纹
  • 纹理检测的卷积核会有特定的图案结构

通过可视化,你能直接看到:"原来模型是用这样的过滤器在看世界!"

2.特征图可视化(模型 "看到的内容")

  • 展示卷积操作后生成的特征图
  • 支持批量显示多个特征图
  • 自动调整布局,使显示更加美观

当输入图像经过卷积核处理后,会生成特征图 —— 这是模型 "看到" 的内容。

代码中的visualize_feature_maps函数会展示这些特征图:

  • 特征图上明亮的区域表示该位置与对应卷积核识别的模式匹配度高
  • 不同的特征图对应不同卷积核 "看到" 的信息

比如处理一张猫的图片,某个特征图可能会在猫的边缘区域特别亮(因为对应的卷积核擅长检测边缘)。


三、关键概念解释

  1. 卷积核可视化的意义卷积核可视化可以帮助我们理解模型学到的特征提取方式:

    • 浅层卷积核通常学习边缘、纹理等基础特征
    • 深层卷积核则学习更复杂的语义特征
    • 通过观察卷积核,我们可以分析模型是否在有效学习
  2. 特征图可视化的意义特征图展示了输入经过卷积核处理后的结果:

    • 每个特征图对应一种特定特征的响应强度
    • 亮的区域表示该位置对特定特征有强响应
    • 帮助我们理解模型关注输入图像的哪些区域

四、两者的区别

  • 卷积核:是模型里的 “过滤器”,它的作用是去输入里找特定的模式(比如边缘、纹理等)。图里每个小方块的亮度,反映的是卷积核自身参数的大小。
  • 特征图:是输入经过卷积核处理后得到的结果。它的亮度反映的是 “输入中哪些部分和卷积核匹配”—— 输入里和卷积核模式匹配的地方,在特征图里会更亮。

简单类比:卷积核是 “模板”,特征图是 “模板在输入里的匹配结果”。


五、使用方法

  1. 对于自定义模型:

    • 将卷积层的 weight 参数传入visualize_kernels函数
    • 将卷积层输出的特征图传入visualize_feature_maps函数
  2. 对于预训练模型:

    • 代码中提供了 VGG16 模型的示例
    • 通过注册钩子 (hook) 可以获取中间层的特征图
    • 可以修改代码查看不同层的卷积核和特征图
  3. 如果你运行代码中的demo_custom_conv(),会看到:

              1. 4 个随机生成的 3x3 卷积核(模型的 "眼睛")

              2.这些卷积核处理随机输入后生成的 4 个特征图(模型 "看到的内容")

                这样你就能直观感受到:卷积核的形状不同,"看到" 的特征也完全不同。

                如果想观察真实图片的处理效果,可以尝试demo_pretrained_model(),它会展示 VGG16 这个经典模型如何 "看世界"。


六、核心参数含义解析

卷积核  特征图可视化核心参数含义解析:《卷积核特征图可视化核心参数含义解析》资源-CSDN下载


通过这个工具:

  • 能直观理解 "卷积" 不是黑箱,而是有规律的特征提取
  • 能观察到不同卷积核的分工(谁负责边缘、谁负责纹理)
  • 能看到输入图像如何一步步被转化为抽象特征
http://www.dtcms.com/a/423940.html

相关文章:

  • 【mysql】Mybatisplus BINARY {0} LIKE CONCAT(‘%‘, {1}, ‘%‘)写这句话是什么意思
  • 开发避坑指南(59):Vue3中高效删除数组元素的方法
  • wordpress建站要用模板吗wordpress搜索筛选
  • 安卓 WPS Office v18.21.0 国际版
  • 衡阳网站推广优化公司行业网站开发运营方案
  • 临海房产中介网站如何制作网站平台管理
  • 做网站多少人建e室内设计网官网平面图
  • git mere 错误后的回滚处理
  • Java开发入门(一)--- JDK与环境变量配置
  • 最好的营销型网站建设公司报电子商务(网站建设与运营)
  • 从0到1制作一个go语言游戏服务器(二)web服务搭建
  • 网站使用流程图昆明网站建设天锐科技
  • (uniapp)基于vue3父子组件间传递参数与方法
  • 铁岭开原网站建设高中课程免费教学网站
  • 高校网站群建设方案网站建设目录结构设计
  • 静态网站源码野花韩国视频在线观看免费高清
  • Windows下NVM保姆级指南:安装、切换版本、指定路径+淘宝镜像配置,一次搞定!
  • 杭州营销型网站建设杭州租车网站建设
  • 网站开发基础知识网站开发怎么连接sqlserver
  • 基于AC6366C做AI语音鼠标
  • 刘诗雯现身TCL品牌活动,雷鸟34Q9显示器同台竞技
  • 东莞百域网站建设公司手机网站开发屏幕尺寸一般是多少
  • 理财经理如何提高职场技能实现晋升
  • 【碎片化学习】SpringBoot中的自动配置(Auto Configuration)
  • PC16550 FIFO接收方式研究
  • 做基金的网站哪个好用什么程序做资讯类网站
  • 图书馆网站建设申请国外做仿牌网站
  • make, makefile, cmake, qmake 有何区别?
  • vite如何处理项目中的资源
  • 文网文网站建设wordpress只显示首页