当前位置: 首页 > news >正文

简单的 PyTorch 示例,可视化和解释 weight decay 的作用

场景:拟合一个简单的正弦函数

我们将训练一个小的神经网络去拟合一个正弦曲线(带噪声),并比较 使用和不使用 weight decay 的效果

1. 准备数据

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np# 构造数据(sin 函数 + 噪声)
x = torch.linspace(0, 2 * np.pi, 100).unsqueeze(1)
y = torch.sin(x) + 0.1 * torch.randn_like(x)  # 添加噪声

2. 定义模型(一个简单的 MLP)

class SimpleNet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(1, 64),nn.Tanh(),nn.Linear(64, 1))def forward(self, x):return self.net(x)

3. 分别训练两个模型(一个使用 weight decay,一个不使用)

def train(weight_decay_value):model = SimpleNet()optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=weight_decay_value)loss_fn = nn.MSELoss()for epoch in range(500):#告诉模型:进入训练模式(比如启用 dropout、batchnorm 的训练行为)。model.train()#前向传播(预测)y_pred = model(x)# 计算损失loss = loss_fn(y_pred, y)# 反向传播准备#清空之前累积的梯度,防止干扰新一轮的计算。optimizer.zero_grad()#反向传播(求梯度)#根据损失反向传播,计算参数的梯度。loss.backward()# 根据梯度和优化器策略(含 weight decay)更新模型参数。optimizer.step()return model

4. 可视化结果对比

model_no_decay = train(0.0)
model_with_decay = train(0.01)# 绘图
plt.figure(figsize=(10, 5))
plt.scatter(x.numpy(), y.numpy(), label='Data', color='gray', alpha=0.5)with torch.no_grad():y_pred1 = model_no_decay(x)y_pred2 = model_with_decay(x)plt.plot(x.numpy(), y_pred1.numpy(), label='No Weight Decay', color='blue')
plt.plot(x.numpy(), y_pred2.numpy(), label='With Weight Decay', color='red')
plt.legend()
plt.title("Effect of Weight Decay on Fitting Sin Curve")
plt.show()

运行结果如下

5. 解释结果

模型效果
❌ 无 weight decay拟合得非常贴近噪声,容易过拟合,曲线不光滑
✅ 有 weight decay曲线更平滑,不那么贴合噪声,更接近真实函数

✅ 原因:

  • 没有 weight decay:网络自由度太高,学到了很多噪声特征;

  • 加了 weight decay:对权重值大小施加了惩罚,网络“更保守”,只学到主要趋势。


🔚 总结

  • weight_decay 本质上就是 L2 正则化,防止参数变得太大;

  • 它可以 减少过拟合、提高泛化能力

  • 在 LoRA 微调、预训练、分类任务中都非常重要;

  • 推荐值通常在 0.01 左右,需调参。

http://www.dtcms.com/a/263570.html

相关文章:

  • 云上攻防—Docker安全容器逃逸特权模式危险挂载
  • 【C++】简单学——模板初阶
  • tauri v2 开源项目学习(一)
  • PSQL 处理 BLOB 类型数据问题
  • 华为云Flexus+DeepSeek征文 | ​​华为云ModelArts Studio大模型与企业AI会议纪要场景的对接方案
  • 数据库事务全面指南:概念、语法、机制与最佳实践
  • C++ 快速回顾(五)
  • 【冷知识】Spring Boot 配置文件外置
  • SpringBoot -- 自动配置原理
  • Bessel位势方程求解步骤
  • STL简介+string模拟实现
  • 「Java案例」计算矩形面积
  • 大数据(3)-Hive
  • 【算法】动态规划:1137. 第 N 个泰波那契数
  • 初等变换 线性代数
  • C++ STL之string类
  • Windows11系统中安装docker并配置docker镜像到pycharm中
  • EA自动交易完全指南:从策略设计到实盘部署
  • SpringBoot 启动入口深度解析:main方法执行全流程
  • Android Telephony 网络状态中的 NAS 信息
  • 反射,枚举和lambda表达式
  • 《垒球百科》老年俱乐部有哪些项目·垒球1号位
  • 从零到一通过Web技术开发一个五子棋
  • 【MySQL基础】MySQL索引全面解析:从原理到实践
  • 人形机器人_双足行走动力学:MIT机器人跌落自恢复算法及应用
  • 使用Verilog设计模块输出中位数,尽可能较少资源使用
  • 本周股指想法
  • 产品背景知识——API、SDK、Library、Framework、Protocol
  • 10.【C语言学习笔记】指针(二)
  • Python 数据分析与机器学习入门 (八):用 Scikit-Learn 跑通第一个机器学习模型