当前位置: 首页 > news >正文

Python 整理3种查看神经网络结构的方法

1. 网络结构代码

import torch
import torch.nn as nn# 定义Actor-Critic模型
class ActorCritic(nn.Module):def __init__(self, state_dim, action_dim):super(ActorCritic, self).__init__()self.actor = nn.Sequential(# 全连接层,输入维度为 state_dim,输出维度为 256nn.Linear(state_dim, 64),nn.ReLU(),nn.Linear(64, action_dim),# Softmax 函数,将输出转换为概率分布,dim=-1 表示在最后一个维度上应用 Softmaxnn.Softmax(dim=-1))self.critic = nn.Sequential(nn.Linear(state_dim, 64),nn.ReLU(),nn.Linear(64, 1))def forward(self, state):policy = self.actor(state)value = self.critic(state)return policy, value# 参数设置
state_dim = 1
action_dim = 2model = ActorCritic(state_dim, action_dim)

2. 查看结构

2.1 直接打印模型

print(model)

输出:

ActorCritic((actor): Sequential((0): Linear(in_features=1, out_features=64, bias=True)(1): ReLU()(2): Linear(in_features=64, out_features=2, bias=True)(3): Softmax(dim=-1))(critic): Sequential((0): Linear(in_features=1, out_features=64, bias=True)(1): ReLU()(2): Linear(in_features=64, out_features=1, bias=True))
)

2.2 可视化网络结构(需要安装 torchviz 包)

安装 torchsummary 包:

$ pip install torchsummary

python 代码:

from torchviz import make_dot# 创建一个虚拟输入
x = torch.randn(1, state_dim)
# 生成计算图
dot = make_dot(model(x), params=dict(model.named_parameters()))
dot.render("actor_critic_model", format="png")  # 保存为PNG图片

输出 actor_critic_model

digraph {graph [size="12,12"]node [align=left fontname=monospace fontsize=10 height=0.2 ranksep=0.1 shape=box style=filled]140281544075344 [label="(1, 2)" fillcolor=darkolivegreen1]140281544213744 [label=SoftmaxBackward0]140281544213840 -> 140281544213744140281544213840 [label=AddmmBackward0]140281544213600 -> 140281544213840140285722327344 [label="actor.2.bias(2)" fillcolor=lightblue]140285722327344 -> 140281544213600140281544213600 [label=AccumulateGrad]140281544214032 -> 140281544213840140281544214032 [label=ReluBackward0]140281544213984 -> 140281544214032140281544213984 [label=AddmmBackward0]140281544214176 -> 140281544213984140285722327024 [label="actor.0.bias(64)" fillcolor=lightblue]140285722327024 -> 140281544214176140281544214176 [label=AccumulateGrad]140281544214224 -> 140281544213984140281544214224 [label=TBackward0]140281543934832 -> 140281544214224140285722327264 [label="actor.0.weight(64, 1)" fillcolor=lightblue]140285722327264 -> 140281543934832140281543934832 [label=AccumulateGrad]140281544213648 -> 140281544213840140281544213648 [label=TBackward0]140281544214080 -> 140281544213648140285722327184 [label="actor.2.weight(2, 64)" fillcolor=lightblue]140285722327184 -> 140281544214080140281544214080 [label=AccumulateGrad]140281544213744 -> 140281544075344140285722328704 [label="(1, 1)" fillcolor=darkolivegreen1]140281544213888 [label=AddmmBackward0]140281544214368 -> 140281544213888140285722328064 [label="critic.2.bias(1)" fillcolor=lightblue]140285722328064 -> 140281544214368140281544214368 [label=AccumulateGrad]140281544214128 -> 140281544213888140281544214128 [label=ReluBackward0]140281544214464 -> 140281544214128140281544214464 [label=AddmmBackward0]140281544214512 -> 140281544214464140285722327424 [label="critic.0.bias(64)" fillcolor=lightblue]140285722327424 -> 140281544214512140281544214512 [label=AccumulateGrad]140281544214560 -> 140281544214464140281544214560 [label=TBackward0]140281544214704 -> 140281544214560140285722327504 [label="critic.0.weight(64, 1)" fillcolor=lightblue]140285722327504 -> 140281544214704140281544214704 [label=AccumulateGrad]140281544213696 -> 140281544213888140281544213696 [label=TBackward0]140281544214272 -> 140281544213696140285722327584 [label="critic.2.weight(1, 64)" fillcolor=lightblue]140285722327584 -> 140281544214272140281544214272 [label=AccumulateGrad]140281544213888 -> 140285722328704
}

输出模型图片:
在这里插入图片描述

2.3 使用 summary 方法(需要安装 torchsummary 包)

安装 torchsummary 包:

pip install torchsummary

代码:

from torchsummary import summarydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)
summary(model, input_size=(state_dim,))#查看模型参数
print("查看模型参数:")
for name, param in model.named_parameters():print(f"Layer: {name} | Size: {param.size()} | Values: {param[:2]}...")

输出:

cuda:0
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                   [-1, 64]             128ReLU-2                   [-1, 64]               0Linear-3                    [-1, 2]             130Softmax-4                    [-1, 2]               0Linear-5                   [-1, 64]             128ReLU-6                   [-1, 64]               0Linear-7                    [-1, 1]              65
================================================================
Total params: 451
Trainable params: 451
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
查看模型参数:
Layer: actor.0.weight | Size: torch.Size([64, 1]) | Values: tensor([[ 0.7747],[-0.0440]], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: actor.0.bias | Size: torch.Size([64]) | Values: tensor([ 0.5995, -0.2155], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: actor.2.weight | Size: torch.Size([2, 64]) | Values: tensor([[ 0.0373,  0.0851,  0.1000,  0.1060,  0.0387,  0.0479,  0.0127,  0.0696,0.0388,  0.0033,  0.1173, -0.1195, -0.0830,  0.0186,  0.0063, -0.0863,-0.0353,  0.0782, -0.0558,  0.0011, -0.0533,  0.1241,  0.0120, -0.0906,-0.0551, -0.0673, -0.1070,  0.0402, -0.0662,  0.0596, -0.0811,  0.0457,0.0349,  0.0564, -0.0155, -0.0404,  0.0843, -0.0978,  0.0459,  0.1097,-0.0858,  0.0736, -0.0067, -0.0756, -0.0363, -0.0525, -0.0426, -0.1087,-0.0611,  0.0420, -0.1038,  0.0402,  0.0065, -0.1217, -0.0467,  0.0383,-0.0217,  0.0283,  0.0800,  0.0228,  0.0415, -0.0473, -0.0199, -0.0436],[-0.1118, -0.0806, -0.0700, -0.0224,  0.0335, -0.0087,  0.0265, -0.1196,-0.0907, -0.0360,  0.0621, -0.0471, -0.0939, -0.0912, -0.1061,  0.1051,-0.0592, -0.0757,  0.0758, -0.1082, -0.0317,  0.1208, -0.0279, -0.0693,0.0920, -0.0318, -0.0476,  0.0236, -0.0761,  0.0591,  0.0862, -0.0712,0.0156, -0.1073,  0.1133,  0.0039, -0.0191,  0.0605, -0.0686, -0.1202,0.0962,  0.0581,  0.1145,  0.0741, -0.0993, -0.0987,  0.0939,  0.1006,0.0773, -0.0756, -0.1096,  0.0156, -0.0599,  0.0857,  0.1005, -0.0618,0.0474,  0.0066, -0.0531, -0.0479,  0.1136,  0.0356,  0.1169, -0.0023]],device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: actor.2.bias | Size: torch.Size([2]) | Values: tensor([-0.0039,  0.0937], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.0.weight | Size: torch.Size([64, 1]) | Values: tensor([[0.5799],[0.0473]], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.0.bias | Size: torch.Size([64]) | Values: tensor([ 0.6507, -0.6974], device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.2.weight | Size: torch.Size([1, 64]) | Values: tensor([[ 0.0738, -0.0370, -0.1010, -0.0333, -0.0595, -0.0172,  0.0928,  0.0815,0.1221, -0.0842,  0.0511,  0.0452, -0.0386, -0.0503, -0.0964,  0.0370,-0.0341, -0.0693, -0.0845,  0.0424, -0.0491, -0.0439, -0.0443,  0.0203,0.0960, -0.1178, -0.0836, -0.0144, -0.0576, -0.0851,  0.0461,  0.1160,0.0120,  0.1180,  0.0255,  0.1047, -0.0398,  0.0786,  0.1143,  0.0806,0.1125,  0.0267,  0.0534, -0.0318,  0.1125, -0.0727,  0.1169,  0.0120,-0.0178, -0.0845,  0.0069,  0.0194,  0.1188,  0.0481,  0.1077, -0.0840,0.1013,  0.0586, -0.0857, -0.0974, -0.0630,  0.0359, -0.0080, -0.0926]],device='cuda:0', grad_fn=<SliceBackward0>)...
Layer: critic.2.bias | Size: torch.Size([1]) | Values: tensor([0.0621], device='cuda:0', grad_fn=<SliceBackward0>)...

相关文章:

  • 虚幻引擎作者采访
  • 什么是原码、反码与补码?
  • 2025流感疫苗指南+卫健委诊疗方案|高危人群防护+并发症处理 慢性肾脏病饮食指南2025卫健委版|低盐低磷食谱+中医调理+PDF 网盘下载 pdf下载
  • 牛客1018逆序数-归并排序
  • 金融的本质是智融、融资的实质是融智、投资的关键是投智,颠覆传统金融学的物质资本中心论,构建了以智力资本为核心的新范式
  • PyTorch 张量与自动微分操作
  • 全球化电商平台Azure云架构设计
  • 期末代码Python
  • iptables的基本选项及概念
  • 串 Part 1
  • 数据链路层(MAC 地址)
  • Gemini 解释蓝图节点的提示词
  • STC单片机与淘晶驰串口屏通讯例程之04【密码登录与修改】
  • 有哪些场景不适合使用Java反射机制
  • 基于C++实现的深度学习(cnn/svm)分类器Demo
  • H3C无线控制器自动信道功率调整典型配置实验
  • 数据结构小扫尾——栈
  • JAVA:使用 Maven Assembly 创建自定义打包的技术指南
  • Kubernetes(k8s)学习笔记(七)--KubeSphere 最小化安装
  • 音频感知动画新纪元:Sonic让你的作品更生动
  • 新闻1+1丨多地政府食堂开放 “舌尖上的服务”,反映出怎样的理念转变?
  • 玉渊谭天丨是自保还是自残?八个恶果透视美国征收100%电影关税
  • 杨德龙:取得长期投资胜利法宝,是像巴菲特一样践行价值投资
  • 五一假期前三日多景区客流刷新纪录,演艺、古镇、山水都很火
  • 猎金,游戏,诚不我欺
  • 党旗下的青春|赵天益:少年确定志向,把最好的时光奉献给戏剧事业