当前位置: 首页 > news >正文

AI笔记 - 模型调试 - 调试方式

模型调试方式

  • 基础信息
  • 打印模型信息
  • 计算参数量和计算量
    • 过滤原则
    • profile方法
    • get_model_complexity_info方法
    • FlopCountAnalysis方法

基础信息

# 打印执行的设备数量:device_count:1
print(f"device_count:{torch.cuda.device_count()}")# 打印当前网络执行的设备信息:device: cuda:0
print(f"device: {next(self.net.parameters()).device}")  # 应该输出: cuda:0

打印模型信息

#操作	    代码示例
#-----------------------------------------------------
#遍历所有模块	for name, module in model.named_modules():
#-----------------------------------------------------
#打印参数详情	module.named_parameters()
#-----------------------------------------------------
#打印缓冲区	module.named_buffers()
#-----------------------------------------------------
#过滤特定层	isinstance(module, nn.Conv2d)
#-----------------------------------------------------
#统计计算量	profile(module, inputs=(input,))
#-----------------------------------------------------import torchvision.models as modelsmodel = models.resnet50(weights=None).cuda()  # 不加载预训练权重以减少下载时间
input = torch.randn(1, 3, 224, 224).cuda()
for name, p in model.named_parameters():
print(f"params name:{name}, shape:{p.shape}, device:{p.device}")
print(f"dtype: {p.dtype}, 是否需要梯度:{p.requires_grad}")#params name:conv1.weight, shape:torch.Size([64, 3, 7, 7]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
#params name:bn1.weight, shape:torch.Size([64]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
#params name:bn1.bias, shape:torch.Size([64]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
...for name, module in model.named_modules():print(f"模块名称:{name}, 模块类型:{type(module).__name__}")# 打印可训练参数(weight/bias)for param_name, param in module.named_parameters():print(f"  - 参数:{param_name} | 形状:{param.shape} | 设备:{param.device} | 需梯度:{param.requires_grad} | 数据类型:{param.dtype}")# 打印缓冲区(如BatchNorm的running_mean)for buffer_name, buffer in module.named_buffers():print(f"  - 缓冲区: {buffer_name} | 形状: {buffer.shape} | 设备: {buffer.device}")# 模块名称:, 模块类型:ResNet
#  - 参数:conv1.weight | 形状:torch.Size([64, 3, 7, 7]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.weight | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.bias | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32......
#  - 缓冲区: bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.running_var | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.num_batches_tracked | 形状: torch.Size([]) | 设备: cuda:0
#  - 缓冲区: layer1.0.bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0......
# 模块名称:layer1.0, 模块类型:Bottleneck
#  - 参数:conv1.weight | 形状:torch.Size([64, 64, 1, 1]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.weight | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32......
#  - 缓冲区: bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.running_var | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.num_batches_tracked | 形状: torch.Size([]) | 设备: cuda:0......
#模块名称:layer1.0.conv1, 模块类型:Conv2d
#  - 参数:weight | 形状:torch.Size([64, 64, 1, 1]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#模块名称:layer1.0.bn1, 模块类型:BatchNorm2d
......

计算参数量和计算量

过滤原则

在计算模型计算量(FLOPs)时,过滤掉 BatchNorm2d、Sequential 和 Bottleneck 等非关键层是常见的需求

层类型是否过滤原因
BatchNorm2d✅ 过滤计算量极小(仅逐通道缩放),可忽略
Sequential✅ 过滤容器层(实际计算在子层)
Bottleneck✅ 过滤复合层(计算量已包含在子层中)
Conv2d/Linear❌ 保留核心计算层
ReLU/Pooling⚠️ 可选通常忽略(或单独统计)

profile方法

from thop import profilemodel = models.resnet50(weights=None).cuda()  # 不加载预训练权重以减少下载时间
input = torch.randn(1, 3, 224, 224).cuda()
flops, params = profile(model, inputs=(input,))
print(f"FLOPs: {flops / 1e9:.2f} G")  # 输出: ~4.11 GFLOPs
print(f"Params: {params / 1e6:.2f} M")  # 输出: ~25.56 Million

get_model_complexity_info方法

from ptflops import get_model_complexity_infomacs, params = get_model_complexity_info(self.net,(3, 1280, 1280),  # (channels, height, width)as_strings=True,print_per_layer_stat=True,  # 打印每层计算量verbose=True,
)
print(f"FLOPs: {macs}")
print(f"Params: {params}")# Warning: module IntermediateLayerGetter,FPN,SSH,ClassHead,BboxHead,LandmarkHead,RetinaFace,DataParallel is treated as a zero-op.
# DataParallel(
#  426.61 k, 100.000% Params, 4.07 GMac, 99.943% MACs, 
#  (module): RetinaFace(
#    426.61 k, 100.000% Params, 4.07 GMac, 99.943% MACs, 
#    (body): IntermediateLayerGetter(
#      213.07 k, 49.946% Params, 1.45 GMac, 35.733% MACs, 
#      (stage1): Sequential(
#        10.13 k, 2.374% Params, 642.25 MMac, 15.774% MACs, 
#        (0): Sequential(
#          232, 0.054% Params, 98.3 MMac, 2.414% MACs, 
#          (0): Conv2d(216, 0.051% Params, 88.47 MMac, 2.173% MACs, 3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#          (1): BatchNorm2d(16, 0.004% Params, 6.55 MMac, 0.161% MACs, 8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#          (2): LeakyReLU(0, 0.000% Params, 3.28 MMac, 0.080% MACs, negative_slope=0.1, inplace=True)
#        )
#        (1): Sequential()
#  ......
# )
# FLOPs: 4.07 GMac
# Params: 426.61 k

FlopCountAnalysis方法

from fvcore.nn import FlopCountAnalysisflops = FlopCountAnalysis(self.net, image)
flops = flops.unsupported_ops_warnings(False)  # 忽略不支持的操作警告# 计算总 FLOPs
print(flops.by_module())  # 打印每个模块的 FLOPs
total_flops = flops.total()
print(f"Total FLOPs: {total_flops / 1e9:.2f} G")# 打印每一层的 FLOPs,返回字典 {模块名: FLOPs}
print(flops.by_module())# 打印按模块分组的 FLOPs
print(flops.by_module_and_operator())  # 更详细的统计

相关文章:

  • 基于ssm的商城系统(全套)
  • 为 MCP Server 提供 Auth 认证,及 Django 实现示例
  • 20250528-C#知识:枚举
  • 学习路之Nginx--不同域名的反向代理
  • MySQL MVCC(多版本并发控制)详解
  • 力扣热题100之二叉树的中序遍历
  • 力扣HOT100之回溯:51. N 皇后
  • 学习python day10
  • 【白雪讲堂】多模态技术:统一认知的优化器
  • [CISCN 2021初赛]glass
  • OpenLayers 加载网格信息
  • Redis 5 种基础数据结构?
  • LiveNVR 直播流拉转:Onvif/RTSP/RTMP/FLV/HLS 支持海康宇视天地 SDK 接入-视频广场页面集成与视频播放说明
  • 《清晰思考》
  • 实验设计与分析(第6版,Montgomery)第4章随机化区组,拉丁方, 及有关设计4.5节思考题4.1~4.4 R语言解题
  • 本地(Linux)编译 MySQL 源码
  • 三、zookeeper 常用shell命令
  • 触控精灵 ADB运行模式填写电脑端IP教程
  • Linux基础 -- Linux 启动调试之深入理解 `initcall_debug` 与 `ignore_loglevel`
  • 从零到一选择AI自动化平台:深度解析n8n、Dify与Coze
  • .red域名做网站好不好/网页制作软件推荐
  • 抖音广告推广/做seo必须有网站吗
  • 哪个网站可以做创意短视频/百度指数工具
  • 网站做的好的公司有/种子搜索神器网页版
  • 做网站一般长宽多少/宁波seo网络推广外包报价
  • 做体育网站/福州网站建设方案外包