当前位置: 首页 > news >正文

python第35天打卡

 

三种不同的模型可视化方法:推荐torchinfo打印summary+权重分布可视化
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchinfo import summary
import numpy as np# 1. 定义示例模型
class CNN(nn.Module):def __init__(self):super().__init__()self.conv_layers = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(16, 32, kernel_size=3),nn.ReLU(),nn.MaxPool2d(2))self.fc_layers = nn.Sequential(nn.Linear(32*6*6, 128),nn.ReLU(),nn.Linear(128, 10))def forward(self, x):x = self.conv_layers(x)x = torch.flatten(x, 1)x = self.fc_layers(x)return xmodel = CNN()# 2. 模型摘要可视化 (使用torchinfo)
print("\n=== 模型结构摘要 ===")
summary(model, input_size=(1, 3, 32, 32),  # (batch, channels, height, width)col_names=["input_size", "output_size", "num_params", "kernel_size"],verbose=1
)# 3. 权重分布可视化
def plot_weight_distribution(model):plt.figure(figsize=(12, 6))# 收集所有权重all_weights = []for name, param in model.named_parameters():if 'weight' in name:flattened = param.detach().cpu().numpy().flatten()all_weights.extend(flattened)# 绘制直方图plt.hist(all_weights, bins=150, alpha=0.7, color='blue', edgecolor='black')plt.title('Model Weights Distribution')plt.xlabel('Weight Value')plt.ylabel('Frequency (log scale)')plt.yscale('log')plt.grid(True, alpha=0.3)plt.show()print("\n=== 权重分布图 ===")
plot_weight_distribution(model)# 4. 逐层权重可视化 (额外方法)
def plot_layer_weights(model):plt.figure(figsize=(15, 10))for i, (name, param) in enumerate(model.named_parameters()):if 'weight' in name:plt.subplot(3, 3, i+1)layer_weights = param.detach().cpu().numpy().flatten()plt.hist(layer_weights, bins=100, alpha=0.7)plt.title(f'{name} weight distribution')plt.grid(True, alpha=0.2)plt.tight_layout()plt.show()print("\n=== 逐层权重分布 ===")
plot_layer_weights(model)# 5. 权重矩阵可视化 (额外方法)
def visualize_weight_matrix(model):plt.figure(figsize=(15, 4))# 获取第一个卷积层的权重conv1_weight = model.conv_layers[0].weight.detach().cpu()# 归一化权重用于显示min_val = torch.min(conv1_weight)max_val = torch.max(conv1_weight)normalized_weights = (conv1_weight - min_val) / (max_val - min_val)# 绘制权重矩阵for i in range(16):  # 显示前16个卷积核plt.subplot(2, 8, i+1)plt.imshow(normalized_weights[i].permute(1, 2, 0))  # CHW -> HWCplt.axis('off')plt.title(f'Kernel {i+1}')plt.suptitle('First Conv Layer Kernels', fontsize=16)plt.tight_layout()plt.show()print("\n=== 卷积核可视化 ===")
visualize_weight_matrix(model)

 

@浙大疏锦行

相关文章:

  • 黑马程序员C++核心编程笔记--1 程序的内存模型
  • Android-kotlin协程学习总结
  • 瑞数6代jsvmp简单分析(天津电子税x局)
  • Linux云计算训练营笔记day17(Python)
  • 【b站计算机拓荒者】【2025】微信小程序开发教程 - chapter3 项目实践 - 3人脸识别采集统计人脸检测语音识别
  • 中间件redis 功能篇 过期淘汰策略和内存淘汰策略 力扣例题实现LRU
  • Unity屏幕适配——适配信息计算和安全区域适配
  • ElectronBot复刻-电路测试篇
  • PMO价值重构:从项目管理“交付机器”到“战略推手”
  • UE5 Niagara 如何让四元数进行旋转
  • Vite Vue3 配置 Composition API 自动导入与项目插件拆分
  • Ubuntu系统rsyslog日志突然占用磁盘空间超大怎么办?
  • mybatis-plus实现增删改查(新手理解版)
  • 弱光环境下如何手持相机拍摄静物:摄影曝光之等效曝光认知
  • 【深度学习-pytorch篇】2. Activation, 多层感知机与LLaMA中的MLP实现解析
  • 学习率及相关优化参数详解:驱动模型高效训练
  • DNS解析过程以及使用的协议名称
  • Pytorch中一些重要的经典操作和简单讲解
  • Fastmcp本地搭建 ,查询本地mysql,接入agent-cursor--详细流程
  • P2278 HNOI2003 操作系统
  • 做网站送优化/手机优化专家
  • 怎么能查到网站是哪个公司做的/百度联系方式人工客服
  • 贵阳高端网站建设/宁波网站建设优化企业
  • 收费网站怎么建立/南昌seo网站管理
  • 贵州 做企业网站的流程/公司做网站推广
  • 携程网建设网站的理由/竞价 推广