当前位置: 首页 > news >正文

动手学深度学习(pytorch版):第五章节—深度学习计算(2)参数管理

目录

1. 参数访问

1.1. 目标参数

1.2. 一次性访问所有参数

1.3. 从嵌套块收集参数

2. 参数初始化

2.1. 内置初始化

2.2. 自定义初始化

2.3. 参数绑定


在选择了架构并设置了超参数后,就进入了训练阶段。 目标是找到使损失函数最小化的模型参数值。 经过训练后,将需要使用这些参数来做出未来的预测。 此外,有时希望提取参数,以便在其他环境中复用它们, 将模型保存下来,以便它可以在其他软件中执行, 或者为了获得科学的理解而进行检查。

之前的介绍中,我们只依靠深度学习框架来完成训练的工作, 而忽略了操作参数的具体细节。介绍以下内容:

  • 访问参数,用于调试、诊断和可视化;

  • 参数初始化;

  • 在不同模型组件间共享参数。

首先看一下具有单隐藏层的多层感知机。

import torch
from torch import nnnet = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)

1. 参数访问

我们从已有模型中访问参数。 当通过Sequential类定义模型时, 我们可以通过索引来访问模型的任意层。 这就像模型是一个列表一样,每层的参数都在其属性中。 如下所示,我们可以检查第二个全连接层的参数。

print(net[2].state_dict())

首先,这个全连接层包含两个参数,分别是该层的权重和偏置。 两者都存储为单精度浮点数(float32)。 注意,参数名称允许唯一标识每个参数,即使在包含数百个层的网络中也是如此。

1.1. 目标参数

注意,每个参数都表示为参数类的一个实例。 要对参数执行任何操作,首先需要访问底层的数值。 有几种方法可以做到这一点。有些比较简单,而另一些则比较通用。

下面的代码从第二个全连接层(即第三个神经网络层)提取偏置, 提取后返回的是一个参数类实例,并进一步访问该参数的值。

print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)

参数是复合的对象,包含值、梯度和额外信息。 这就是我们需要显式参数值的原因。 除了值之外,我们还可以访问每个参数的梯度。 在上面这个网络中,由于我们还没有调用反向传播,所以参数的梯度处于初始状态。

net[2].weight.grad == None

1.2. 一次性访问所有参数

当需要对所有参数执行操作时,逐个访问它们可能会很麻烦。 当处理更复杂的块时,情况可能会变得特别复杂, 因为需要递归整个树来提取每个子块的参数。 下面,将通过演示来比较访问第一个全连接层的参数和访问所有层。

print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])

这提供了另一种访问网络参数的方式

net.state_dict()['2.bias'].data

1.3. 从嵌套块收集参数

首先定义一个生成块的函数(可以说是“块工厂”),然后将这些块组合到更大的块中。

def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())def block2():net = nn.Sequential()for i in range(4):# 在这里嵌套net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)

设计了网络后,它是如何工作的。

print(rgnet)

因为层是分层嵌套的,所以也可以像通过嵌套列表索引一样访问它们。 下面,访问第一个主要的块中、第二个子块的第一层的偏置项。

rgnet[0][1][0].bias.data

2. 参数初始化

知道了如何访问参数后,现在看看如何正确地初始化参数。  深度学习框架提供默认随机初始化, 也允许我们创建自定义初始化方法, 满足我们通过其他规则实现初始化权重。

默认情况下,PyTorch会根据一个范围均匀地初始化权重和偏置矩阵, 这个范围是根据输入和输出维度计算出的。 PyTorch的nn.init模块提供了多种预置初始化方法。

2.1. 内置初始化

首先调用内置的初始化器。 下面的代码将所有权重参数初始化为标准差为0.01的高斯随机变量, 且将偏置参数设置为0。

def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

还可以将所有参数初始化为给定的常数,比如初始化为1。

def init_constant(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 1)nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

还可以对某些块应用不同的初始化方法。 例如,下面使用Xavier初始化方法初始化第一个神经网络层, 然后将第三个神经网络层初始化为常量值42

def init_xavier(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)
def init_42(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 42)net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)

2.2. 自定义初始化

同样,我们实现了一个my_init函数来应用到net

def my_init(m):if type(m) == nn.Linear:print("Init", *[(name, param.shape)for name, param in m.named_parameters()][0])nn.init.uniform_(m.weight, -10, 10)m.weight.data *= m.weight.data.abs() >= 5net.apply(my_init)
net[0].weight[:2]

注意,始终可以直接设置参数。

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]

2.3. 参数绑定

可以定义一个稠密层,然后使用它的参数来设置另一个层的参数。

# 我们需要给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared, nn.ReLU(),shared, nn.ReLU(),nn.Linear(8, 1))
net(X)
# 检查参数是否相同
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
# 确保它们实际上是同一个对象,而不只是有相同的值
print(net[2].weight.data[0] == net[4].weight.data[0])

这个例子表明第三个和第五个神经网络层的参数是绑定的。 它们不仅值相等,而且由相同的张量表示。

http://www.dtcms.com/a/342653.html

相关文章:

  • 进程和进程调度
  • Rclone入门对象存储云到云迁移
  • 我从零开始学微积分(2)- 函数与图形
  • YOLO --- YOLOv3以及YOLOv4模型详解
  • Redis Hash数据类型深度解析:从命令、原理到实战场景
  • IPSEC安全基础后篇
  • 易焓仪器安全帽耐熔融金属飞溅性能测试仪:飞溅场景适配与精准检测
  • 力扣 30 天 JavaScript 挑战 第37天 第九题笔记 知识点: 剩余参数,拓展运算符
  • 智慧农业温室大棚远程监控物联网系统解决方案
  • CRaxsRat v7.4:网络安全视角下的深度解析与防护建议
  • AECS(国标ECALL GB 45672-2025)
  • 5G视频终端详解 无人机图传 无线图传 便携式5G单兵图传
  • 汇总图片拖进ps中 photoshop同时打开几个文件夹
  • 【论文阅读 | TCSVT 2025 | CFMW:面向恶劣天气下鲁棒目标检测的跨模态融合Mamba模型】
  • 深入理解Docker网络:从docker0到自定义网络
  • 设计简洁的Ansible:目前非常流行的开源配置管理和自动化工具
  • webrtc中win端音频---windows Core Audio
  • Mysql基础(②锁)
  • 想在手机上操作服务器?cpolar让WaveTerminal终端随身携带,效率倍增
  • 高并发短信网关平台建设方案概述
  • 打造医疗新质生产力
  • nodejs安装后 使用npm 只能在cmd 里使用 ,但是不能在poowershell使用,只能用npm.cmd
  • ES_多表关联
  • Linux 信号 (Signals)
  • 鱼眼相机去畸变的算法原理(一)
  • WEB服务器(静态/动态网站搭建)
  • 循环神经网络实战:用 LSTM 做中文情感分析(二)
  • Mokker AI:一键更换照片背景的AI神器
  • 鸿蒙生态开发全栈指南
  • mac的m3芯片安装mysql