当前位置: 首页 > news >正文

【深度学习-pytorch篇】4. 正则化方法(Regularization Techniques)

正则化方法(Regularization Techniques)

1. 目标

  • 理解什么是过拟合及其影响
  • 掌握常见正则化技术:L2 正则化、Dropout、Batch Normalization、Early Stopping
  • 能够使用 PyTorch 编程实现这些正则化方法并进行比较分析

2. 数据构造与任务设定

本实验是一个带噪声的回归任务,目标函数为 y = x + N ( 0 , σ 2 ) y = x + \mathcal{N}(0, \sigma^2) y=x+N(0,σ2)。使用均匀分布采样输入 x ∈ [ − 1 , 1 ] x \in [-1, 1] x[1,1]

import numpy as np
import torch
import torch.utils.data as DataN_SAMPLES = 20
NOISE_RATE = 0.4train_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
train_y = train_x + np.random.normal(0, NOISE_RATE, train_x.shape)validate_x = np.linspace(-1, 1, N_SAMPLES // 2)[:, np.newaxis]
validate_y = validate_x + np.random.normal(0, NOISE_RATE, validate_x.shape)test_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
test_y = test_x + np.random.normal(0, NOISE_RATE, test_x.shape)# 转换为 Tensor
train_x = torch.tensor(train_x, dtype=torch.float32)
train_y = torch.tensor(train_y, dtype=torch.float32)
validate_x = torch.tensor(validate_x, dtype=torch.float32)
validate_y = torch.tensor(validate_y, dtype=torch.float32)
test_x = torch.tensor(test_x, dtype=torch.float32)
test_y = torch.tensor(test_y, dtype=torch.float32)train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)

3. 模型定义

3.1 原始 MLP(无正则化)

import torch.nn as nn
import torch.nn.init as initclass FC_Classifier(nn.Module):def __init__(self, input_dim=1, hidden_dim=100, output_dim=1):super().__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, output_dim)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.activation(self.fc1(x))return self.fc2(x)

3.2 Dropout MLP

class DropoutMLP(nn.Module):def __init__(self, dropout_rate=0.5):super().__init__()self.fc1 = nn.Linear(1, 100)self.dropout = nn.Dropout(dropout_rate)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.dropout(self.fc1(x))x = self.activation(x)return self.fc2(x)

3.3 Batch Normalization MLP

class BNMLP(nn.Module):def __init__(self):super().__init__()self.bn_input = nn.BatchNorm1d(1)self.fc1 = nn.Linear(1, 100)self.bn_hidden = nn.BatchNorm1d(100)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()def forward(self, x):x = self.bn_input(x)x = self.fc1(x)x = self.bn_hidden(x)x = self.activation(x)return self.fc2(x)

4. Early Stopping 策略

当验证集误差连续若干轮无提升时,提前停止训练,避免过拟合。

max_patience = 5
patience = 0
best_val_loss = float("inf")
is_early_stop = False

5. RMSNorm 实现与讲解

5.1 原理说明

RMSNorm 是一种替代 LayerNorm 的轻量化归一化方法:

  • 不减均值
  • 仅用激活值的均方根进行归一化
  • 不依赖 batch 维度

数学公式:

RMS ( x ) = 1 n ∑ i = 1 n x i 2 \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2} RMS(x)=n1i=1nxi2

RMSNorm ( x ) = x RMS ( x ) + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x) + \epsilon} \cdot \gamma RMSNorm(x)=RMS(x)+ϵxγ

其中 γ \gamma γ 为可学习参数, ϵ \epsilon ϵ 是一个很小的数避免除以 0。

5.2 代码实现

class RMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.eps = epsdef forward(self, x):rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)return self.weight * x / rms

5.3 与其他归一化对比

方法是否减均值是否除方差是否依赖 batch
BatchNorm
LayerNorm
RMSNorm是 (仅 RMS)

6. 实验建议

  • 尝试不同的 Dropout 比例(如 0.1 / 0.3 / 0.5)并观察效果;
  • 对比是否每层都加 BatchNorm 是否更优;
  • 比较 L2 正则项中 weight decay 的不同取值;
  • 使用 RMSNorm 替代 LayerNorm 做对比实验。

相关文章:

  • ParakeetTDT0.6BV2,语音识别ASR,极速转录, 高精度英文转录,标点支持(附整合包)
  • 常用算法模板函数(Python)
  • 用Python玩转人工智能——手搓图像分类模型
  • 【PhysUnits】13 改进减法(sub.rs)
  • 【加密算法】
  • 从“被动养老”到“主动健康管理”:平台如何重构代际关系?
  • Odoo 条码功能全面深度解析(VIP15万字版)
  • LiveNVR :实现非国标流转国标流的全方位解决方案
  • 勾股数的性质和应用
  • 接地气的方式认识JVM(一)
  • 通过teamcity cloud创建你的一个build
  • 【C语言】详解 指针
  • Java开发之定时器学习
  • 欧拉角转为旋转矩阵
  • 二叉树的锯齿形层序遍历——灵活跳跃的层次结构解析
  • w~视觉~合集6
  • 自我觉察是成长的第一步,如何构建内心的平静
  • 【线程与进程区别】
  • Spring AI框架快速入门
  • 华为OD机试真题——最佳的出牌方法(2025A卷:200分)Java/python/JavaScript/C/C++/GO最佳实现
  • 宿州银行网站建设/什么关键词能搜到资源
  • 企业网站内使用了哪些网络营销方式/乔拓云网站注册
  • 昆明电商网站建设/怎样策划一个营销型网站
  • 深圳建设管理中心网站首页/上海网络营销上海网络推广
  • 成都建站程序/宁波网络营销公司
  • 如何开网站赚钱/seo工具不包括