当前位置: 首页 > news >正文

深度学习:Dropout 技术

在深度学习的领域,模型的复杂性和灵活性使得它们在训练数据上表现出色,但同时也容易导致过拟合。过拟合是指模型在训练数据上表现良好,但在未见数据上表现不佳。

为了解决这个问题,研究人员提出了多种正则化技术,其中 Dropout 是一种非常有效且广泛使用的方法。

在训练过程中,Dropout 会以一定的概率随机选择一部分神经元,并将它们的输出设置为零
这意味着在每个训练批次中,模型的结构会有所不同,从而减少了模型对特定神经元的依赖

1、Dropout 的原理

在训练阶段,Dropout 会以设定的概率(通常在 0.2 到 0.5 之间)随机选择一部分神经元并将其输出设置为零。例如,如果 Dropout 概率为 0.2,那么在每个训练步骤中,约 20% 的神经元将被丢弃。

PyTorch 的 Dropout 层在训练模式和评估模式下的行为是不同的:

  • 训练模式:在训练模式下,Dropout 会随机丢弃一部分神经元,并将保留的神经元的输出乘以1p\frac{1}{p}p1(即11−dropout_prob\frac{1}{1 - \text{dropout\_prob}}1dropout_prob1),以确保在训练过程中输出的期望值与未使用 Dropout 时相同。

  • 评估模式:在评估模式下,Dropout 不会丢弃任何神经元,所有神经元的输出都被使用,输出值保持不变。

Dropout 的实现大致如下(简化版):

def forward(self, x):if self.training:  # 训练模式mask = (torch.rand(x.size()) < self.p).float()x = x * mask / self.p  # 应用掩码并缩放if not self.training:  # 评估模式x = xreturn x

二、 Dropout 训练和测试输出对比

在使用 Dropout 的过程中,尽管训练和测试阶段的输出值可能不完全相同,但它们的期望值是相同的

以下是一个简单的示例,展示了如何使用 PyTorch 实现 Dropout,并计算训练和测试阶段的输出期望值。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 定义简单的线性模型
class SimpleLinearModel(nn.Module):def __init__(self):super(SimpleLinearModel, self).__init__()self.linear = nn.Linear(10, 1)  # 输入特征数量为 10,输出特征数量为 1self.dropout = nn.Dropout(p=0.2)  # Dropout 层,保留概率为 0.8def forward(self, x):x = self.linear(x)  # 线性层x = self.dropout(x)  # 应用 Dropoutreturn x# 创建模型实例
model = SimpleLinearModel()# 模拟输入数据
input_data = torch.randn(100, 10)  # 100 个样本,每个样本 10 个特征# 训练阶段
model.train()
train_outputs = []
for _ in range(100):  # 进行 100 次训练output = model(input_data)train_outputs.append(output.mean().item())  # 记录输出的期望值# 测试阶段
model.eval()  # 切换到评估模式
test_outputs = []
for _ in range(100):  # 进行 100 次测试output = model(input_data)test_outputs.append(output.mean().item())  # 记录输出的期望值# 绘制训练和测试输出的对比
plt.figure(figsize=(12, 6))
plt.plot(train_outputs, label='Training Outputs', color='blue', marker='o', markersize=4, linestyle='-')
plt.plot(test_outputs, label='Testing Outputs', color='orange', marker='x', markersize=4, linestyle='--')
plt.title('Comparison of Training and Testing Outputs')
plt.xlabel('Iteration')
plt.ylabel('Output Mean Value')
plt.legend()
plt.grid()
plt.show()
# 计算期望值
train_expectation = sum(train_outputs) / len(train_outputs)
test_expectation = sum(test_outputs) / len(test_outputs)print(f'Training Expectation: {train_expectation:.4f}')
print(f'Testing Expectation: {test_expectation:.4f}')

在上述代码中,我们定义了一个简单的线性模型,并在训练和测试阶段分别记录了输出的期望值。由于 Dropout 的存在,训练阶段的输出会受到随机丢弃的影响,但通过缩放,保留的神经元的输出期望值与测试阶段的输出保持一致。
在这里插入图片描述

三、为什么 Dropout 可以解决过拟合?

  1. 随机失活神经元:Dropout 的核心思想是随机丢弃一部分神经元的输出。在每个训练步骤中,Dropout 以一定的概率(通常在 0.2 到 0.5 之间)随机选择神经元并将其输出设置为零。这种随机性使得模型在每次训练时都在不同的子网络上进行学习,从而减少了对特定神经元的依赖。

  2. 减少共适应性:在没有 Dropout 的情况下,神经元之间可能会形成强烈的共适应性,即某些神经元的输出依赖于其他神经元的输出。这种共适应性可能导致模型在训练数据上过拟合。通过随机丢弃神经元,Dropout 促使模型学习到更为独立和通用的特征,从而降低了共适应性。

  3. 增强模型的鲁棒性:Dropout 引入随机性,增强了模型的鲁棒性。模型在训练过程中必须适应不同的子网络,这使得它能够更好地处理未见数据,从而在面对新数据时表现得更加稳定和可靠。

  4. 期望输出一致性:Dropout 在训练阶段通过缩放保留神经元的输出,确保训练和测试阶段的期望输出一致。这种一致性使得模型在训练时能够学习到有效的特征,而在测试时能够利用这些特征进行准确的预测。


文章转载自:

http://HjNxsw7p.rcttz.cn
http://Z6oz8BXF.rcttz.cn
http://MOkY1jhP.rcttz.cn
http://EACoW4Dd.rcttz.cn
http://bNTthgop.rcttz.cn
http://R5W4GZRZ.rcttz.cn
http://1ri9Udun.rcttz.cn
http://TT9UPwm2.rcttz.cn
http://DQMomVBS.rcttz.cn
http://L77sqTze.rcttz.cn
http://Zh9oPC2v.rcttz.cn
http://uqT10Cs8.rcttz.cn
http://gXPgRU7D.rcttz.cn
http://nKbnk4Kb.rcttz.cn
http://my2RuYcr.rcttz.cn
http://72sghPIt.rcttz.cn
http://d53VHicS.rcttz.cn
http://ydwONZ09.rcttz.cn
http://rl4TFNTr.rcttz.cn
http://1w2WRfns.rcttz.cn
http://e8TPleUI.rcttz.cn
http://N0F4iXII.rcttz.cn
http://jFTz47ga.rcttz.cn
http://kN9pJuTS.rcttz.cn
http://Fd2Jengt.rcttz.cn
http://ywteoWFy.rcttz.cn
http://f96cfzT6.rcttz.cn
http://LD6BF6Vs.rcttz.cn
http://ESZ61hO9.rcttz.cn
http://nmhHzmQi.rcttz.cn
http://www.dtcms.com/a/368091.html

相关文章:

  • Linux 磁盘扩容及分区相关操作实践
  • 【前端】使用Vercel部署前端项目,api转发到后端服务器
  • 【ARDUINO】ESP8266的AT指令返回内容集合
  • Netty从0到1系列之Netty整体架构、入门程序
  • 实战记录:H3C路由器IS-IS Level-1邻居建立与路由发布
  • iOS 抓包工具有哪些?常见问题与对应解决方案
  • 【Linux】网络安全管理:SELinux 和 防火墙联合使用 | Redhat
  • Boost搜索引擎 网络库与前端(4)
  • 服务器硬盘“Unconfigured Bad“状态解决方案
  • 警惕!你和ChatGPT的对话,可能正在制造分布式妄想
  • 中天互联:AI 重塑制造,解锁智能生产新效能​
  • 如何制造一个AI Agent:从“人工智障”到“人工智能”的奇幻漂流
  • 鼓励员工提出建议,激发参与感——制造企业软件应用升级的密钥
  • 2025世界职校技能大赛总决赛争夺赛汽车制造与维修赛道比赛资讯
  • LeetCode 240: 搜索二维矩阵 II - 算法详解(秒懂系列
  • [特殊字符] AI时代依然不可或缺:精通后端开发的10个GitHub宝藏仓库
  • 【MFC】对话框节点属性:Condition(条件)
  • 【MFC 小白日记】对话框编辑器里“原型图像”到底要不要勾?3 分钟看懂!
  • 【为YOLOv11Seg添加MFC界面】详细指南
  • VBA 中使用 ADODB 操作 SQLite 插入中文乱码问题
  • Python 实现 Markdown 与 Word 高保真互转(含批量转换)
  • 如何在 C# 中将文本转换为 Word 以及将 Word 转换为文本
  • 电商企业如何选择高性价比仓储系统?专业定制+独立部署,源码交付无忧
  • Mysql:由逗号分隔的id组成的varchar联表替换成对应文字
  • Windows环境下实现GitLab与Gitee仓库代码提交隔离
  • PXM的JAVA并发编程学习总结
  • Cursor Pair Programming:在前端项目里用 AI 快速迭代 UI 组件
  • java面试中经常会问到的集合问题有哪些(基础版)
  • 23种设计模式——桥接模式 (Bridge Pattern)详解
  • AI日报 - 2025年09月05日