当前位置: 首页 > news >正文

动手学深度学习(pytorch版):第四章节—多层感知机(7、8)数值稳定性和模型初始化

到目前为止,实现的每个模型都是根据某个预先指定的分布来初始化模型的参数。 这些初始化方案的选择可以与非线性激活函数的选择有趣的结合在一起。 我们选择哪个函数以及如何初始化参数可以决定优化算法收敛的速度有多快。 糟糕选择可能会导致我们在训练时遇到梯度爆炸或梯度消失。 

1. 梯度消失和梯度爆炸

不稳定梯度带来的风险不止在于数值表示; 不稳定梯度也威胁到我们优化算法的稳定性。 我们可能面临一些问题。 要么是梯度爆炸(gradient exploding)问题: 参数更新过大,破坏了模型的稳定收敛; 要么是梯度消失(gradient vanishing)问题: 参数更新过小,在每次更新时几乎不会移动,导致模型无法学习。

1.1. 梯度消失

sigmoid函数为什么会导致梯度消失。

%matplotlib inline
import torch
from d2l import torch as d2lx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
y = torch.sigmoid(x)
y.backward(torch.ones_like(x))d2l.plot(x.detach().numpy(), [y.detach().numpy(), x.grad.numpy()],legend=['sigmoid', 'gradient'], figsize=(4.5, 2.5))

正如上图,当sigmoid函数的输入很大或是很小时,它的梯度都会消失。 此外,当反向传播通过许多层时,除非在刚刚好的地方, 这些地方sigmoid函数的输入接近于零,否则整个乘积的梯度可能会消失。 当我们的网络有很多层时,除非我们很小心,否则在某一层可能会切断梯度。 事实上,这个问题曾经困扰着深度网络的训练。

1.2. 梯度爆炸

相反,梯度爆炸可能同样令人烦恼。 为了更好地说明这一点,生成100个高斯随机矩阵,并将它们与某个初始矩阵相乘。 对于选择的尺度(方差),矩阵乘积发生爆炸。 当这种情况是由于深度网络的初始化所导致时,我们没有机会让梯度下降优化器收敛。

M = torch.normal(0, 1, size=(4,4))
print('一个矩阵 \n',M)
for i in range(100):M = torch.mm(M,torch.normal(0, 1, size=(4, 4)))print('乘以100个矩阵后\n', M)

1.3. 打破对称性

神经网络设计中的另一个问题是其参数化所固有的对称性。 假设有一个简单的多层感知机,它有一个隐藏层和两个隐藏单元。 在这种情况下,可以对第一层的权重进行重排列, 并且同样对输出层的权重进行重排列,可以获得相同的函数。 第一个隐藏单元与第二个隐藏单元没有什么特别的区别。 换句话说,在每一层的隐藏单元之间具有排列对称性。

2. 参数初始化

解决(或至少减轻)上述问题的一种方法是进行参数初始化, 优化期间的注意和适当的正则化也可以进一步提高稳定性。

2.1. 默认初始化

我们使用正态分布来初始化权重值。如果我们不指定初始化方法, 框架将使用默认的随机初始化方法,对于中等难度的问题,这种方法通常很有效。

2.2. Xavier初始化

某些没有非线性的全连接层输出(例如,隐藏变量)的尺度分布。

权重都是从同一分布中独立抽取的。 此外,让我们假设该分布具有零均值和方差。 请注意,这并不意味着分布必须是高斯的,只是均值和方差需要存在。 现在,让我们假设层的输入也具有零均值和方差, 并且它们独立于并且彼此独立。

保持方差不变的一种方法是设置。 现在考虑反向传播过程,我们面临着类似的问题,尽管梯度是从更靠近输出的层传播的。 使用与前向传播相同的推断,我们可以看到,除非, 否则梯度的方差可能会增大,其中是该层的输出的数量。 这使得我们进退两难:我们不可能同时满足这两个条件。

2.3. 额外阅读

上面的推理仅仅触及了现代参数初始化方法的皮毛。 深度学习框架通常实现十几种不同的启发式方法。 此外,参数初始化一直是深度学习基础研究的热点领域。 其中包括专门用于参数绑定(共享)、超分辨率、序列模型和其他情况的启发式算法。

http://www.dtcms.com/a/339310.html

相关文章:

  • 《算法导论》第 31 章 - 数论算法
  • 个人介绍CSDNmjhcsp
  • Kubernetes集群安装部署--flannel
  • Vue 2 项目中快速集成 Jest 单元测试(超详细教程)
  • 云计算学习100天-第23天
  • github 上传代码步骤
  • 【Python】新手入门:python模块是什么?python模块有什么作用?什么是python包?
  • Day13_【DataFrame数据组合merge连接】【案例】
  • 嵌入式开发学习———Linux环境下网络编程学习(三)
  • 第5.5节:awk算术运算
  • RabbitMQ:交换机(Exchange)
  • LeetCode-17day:贪心算法
  • 95、23种设计模式之建造者模式(4/23)
  • 大模型 + 垂直场景:搜索/推荐/营销/客服领域开发新范式与技术实践
  • 抓取手机游戏相关数据
  • 细化的 Spring Boot 和 Spring Framework 版本对应关系
  • c++计算器(简陋版)
  • 【全面推导】策略梯度算法:公式、偏差方差与进化
  • 差分(附带例题题解)
  • 深度学习 --- 基于ResNet50的野外可食用鲜花分类项目代码
  • 基于单片机身体健康监测/身体参数测量/心率血氧血压
  • 接口性能测试工具 - JMeter
  • . keepalived+haproxy
  • Ubuntu22.04安装docker最新教程,包含安装自动脚本
  • 【QT入门到晋级】进程间通信(IPC)-socket(包含详细分析及性能优化)
  • Day08 Go语言学习
  • C#/.NET/.NET Core技术前沿周刊 | 第 50 期(2025年8.11-8.17)
  • es7.x es的高亮与solr高亮查询的对比对比说明
  • 彻底清理旧版本 Docker 的痕迹
  • pytorch学习笔记-模型训练、利用GPU加速训练(两种方法)、使用模型完成任务