当前位置: 首页 > news >正文

PyTorch参数管理详解:从访问到初始化与共享

本文通过实例代码讲解如何在PyTorch中管理神经网络参数,包括参数访问、多种初始化方法、自定义初始化以及参数绑定技术。所有代码可直接运行,适合深度学习初学者进阶学习。


1. 定义网络与参数访问

1.1 定义单隐藏层多层感知机

import torch
from torch import nn

# 定义单隐藏层多层感知机
net1 = nn.Sequential(
    nn.Linear(4, 8),  # 输入层4维,隐藏层8维
    nn.ReLU(),
    nn.Linear(8, 1)   # 输出层1维
)
x = torch.rand(2, 4)  # 随机生成2个4维输入向量
net1(x)                # 前向传播

1.2 访问网络参数

# 访问第二层(索引2)的参数(权重和偏置)
print(net1[2].state_dict())

# 查看参数类型、数据和梯度
print(type(net1[2].bias))    # 类型:Parameter
print(net1[2].bias)          # 参数值(含梯度信息)
print(net1[2].bias.data)     # 参数数据(张量)
print(net1[2].bias.grad)     # 梯度(未反向传播时为None)

1.3 批量访问参数

# 访问第一层的参数名称和形状
print(*[(name, param.shape) for name, param in net1[0].named_parameters()])

# 访问整个网络的参数
print(*[(name, param.shape) for name, param in net1.named_parameters()])

# 通过state_dict直接访问参数数据
print(net1.state_dict()['2.bias'].data)

2. 参数初始化方法

2.1 内置初始化

# 正态分布初始化权重,偏置置零
def init_normal(model):
    if isinstance(model, nn.Linear):
        nn.init.normal_(model.weight, mean=0, std=0.01)
        nn.init.zeros_(model.bias)

net1.apply(init_normal)
print(net1[0].weight.data[0], net1[0].bias.data[0])

# 常数初始化(权重为1,偏置为0)
def init_constant(model):
    if isinstance(model, nn.Linear):
        nn.init.constant_(model.weight, 1)
        nn.init.zeros_(model.bias)

net1.apply(init_constant)
print(net1[0].weight.data[0], net1[0].bias.data[0])

2.2 分层初始化

# 对第一层使用Xavier初始化,第二层使用常数42初始化
def xavier(model):
    if isinstance(model, nn.Linear):
        nn.init.xavier_uniform_(model.weight)

def init_42(model):
    if isinstance(model, nn.Linear):
        nn.init.constant_(model.weight, 42)

net1[0].apply(xavier)
net1[2].apply(init_42)
print(net1[0].weight.data[0])
print(net1[2].weight.data)

2.3 自定义初始化

# 自定义初始化:权重在[-10,10]均匀分布,并过滤绝对值小于5的值
def my_init(model):
    if isinstance(model, nn.Linear):
        print(f'init weight {model.weight.shape}')
        nn.init.uniform_(model.weight, -10, 10)
        model.weight.data *= (model.weight.abs() >= 5)

net1.apply(my_init)
print(net1[0].weight.data[:2])  # 显示前两行权重

3. 参数绑定与共享

3.1 直接修改参数

# 直接操作参数数据
net1[0].weight.data[:] += 1     # 所有权重+1
net1[0].weight.data[0, 0] = 42  # 修改特定位置权重
print(net1[0].weight.data[0])   # 输出第一行权重

3.2 参数共享

# 共享线性层参数
shared_layer = nn.Linear(8, 8)
net3 = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(),
    shared_layer, nn.ReLU(),     # 第2层
    shared_layer, nn.ReLU(),     # 第4层(共享参数)
    nn.Linear(8, 1)
)

# 验证参数共享
print(net3[2].weight.data[0] == net3[4].weight.data[0])  # 输出全True
net3[2].weight.data[0, 0] = 100
print(net3[2].weight.data[0] == net3[4].weight.data[0])  # 修改后仍为True

4. 嵌套网络结构

# 构建嵌套网络
def model1():
    return nn.Sequential(
        nn.Linear(4, 8), nn.ReLU(),
        nn.Linear(8, 4), nn.ReLU()
    )

def model2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'model{i}', model1())
    return net

rgnet = nn.Sequential(model2(), nn.Linear(4, 1))
print(rgnet)  # 打印网络结构

总结

本文演示了PyTorch中参数管理的核心操作,包括:

  • 通过state_dictnamed_parameters访问参数

  • 使用内置初始化方法(正态分布、常数、Xavier)

  • 自定义初始化逻辑

  • 参数的直接修改与共享

  • 复杂嵌套网络的定义

掌握这些技能可以更灵活地设计和优化神经网络模型。建议读者在实践中结合具体任务调整初始化策略,并注意参数共享时的梯度传播特性。


提示:以上代码需要在PyTorch环境中运行,建议使用Jupyter Notebook逐步调试以观察中间结果。

http://www.dtcms.com/a/112935.html

相关文章:

  • ARM架构与编程学习(四)(08_keil_gcc_Makefile)
  • 晶晨S905-S905L-S905LB_S905M2通刷_安卓6.0.1_16S极速开机_线刷固件包
  • 英语—四级CET4考试—蒙猜篇—匹配题
  • 测试:正交法设计测试用例
  • mysql数据库中getshell的方式总结
  • Java进阶-day06:反射、注解与动态代理深度解析
  • GPU显存占用高但利用率低的深度解析 (基于实际案例与技术文档)
  • python爬虫爬取淘宝热销(热门)台式电脑商品信息(课程设计;提供源码、使用说明文档及相关文档;售后可联系博主)
  • php8 命名参数使用教程
  • 跳跃连接(Skip Connection)与残差连接(Residual Connection)
  • 家庭路由器wifi设置LAN2LAN和LAN2WAN
  • STM32低功耗模式详解:睡眠、停机、待机模式原理与实践(下) | 零基础入门STM32第九十三步
  • 30信号和槽_带参数的信号槽(3)
  • [Linux]进程状态、僵尸进程处理回收、进程优先级 + 图例展示
  • kali——httrack
  • Tensorflow、Pytorch与Python、CUDA版本的对应关系(更新时间:2025年4月)
  • 6.1 python加载win32或者C#的dll的方法
  • 对应列表数据的分割和分组
  • 【瑞萨 RA-Eco-RA2E1-48PIN-V1.0 开发板测评】PWM
  • tkiner模块的初步学习
  • 冷门预警,英超006:埃弗顿VS阿森纳,阿森纳分心欧冠,太妃糖或有机可乘
  • TDengine 3.3.6.0 版本中非常实用的 Cols 函数
  • Vue.js设计与实现学习
  • 走进未来的交互世界:下一代HMI设计趋势解析
  • 第九章Python语言高阶加强-面向对象篇
  • 基于Python的微博数据采集
  • 架构及大数据-Zookeeper与Kafka的关系及使用依赖,二者需要同时使用吗?KRaft模式又是啥?
  • Linux常用命令详解:从基础到进阶
  • 基于Python+Flask的服装零售商城APP方案,用到了DeepSeek AI、个性化推荐和AR虚拟试衣功能
  • DCMM详解