当前位置: 首页 > news >正文

pytorch checkpointing

是一种在训练深度神经网络时通过增加计算代价来换取显存优化的技术。它的核心思想是:在反向传播过程中动态重新计算中间激活值(activations),而不是保存所有中间结果。这对于显存受限的场景(如训练大型模型)非常有用。

直接上代码:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint# 1. 定义一个简单的 FFN 模型
class SimpleFFN(nn.Module):def __init__(self, input_dim=128, hidden_dim=256, output_dim=10):super().__init__()self.linear1 = nn.Linear(input_dim, hidden_dim)self.linear2 = nn.Linear(hidden_dim, hidden_dim)self.linear3 = nn.Linear(hidden_dim, output_dim)self.relu = nn.ReLU()def forward(self, x):# 2. 定义一个自定义的前向传播函数(用于 checkpoint)def custom_forward(x):x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.relu(x)x = self.linear3(x)return x# 3. 使用 checkpoint 包装前向传播return checkpoint(custom_forward, x)# 4. 初始化模型和数据
model = SimpleFFN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()# 模拟输入数据
input_data = torch.randn(64, 128)  # batch_size=64, input_dim=128
target = torch.randn(64, 10)       # 模拟目标输出# 5. 前向传播、损失计算和反向传播
output = model(input_data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
  • 在反向传播时,custom_forward 会被重新调用,从输入 x 重新计算中间激活值,从而节省显存。
  • 显存占用:仅保存 linear3 的输出和 x,中间激活值在反向传播时动态计算。
  • 需要多次前向计算激活值,训练速度可能变慢

相关文章:

  • AD创建元件符号
  • 蓝桥杯 17. 通电
  • 嵌入式硬件设计全解析:从架构到实战
  • C/C++滑动窗口算法深度解析与实战指南
  • 89. 格雷编码
  • 时间同步服务
  • 签名去背景图像处理实例
  • Qwen3的“混合推理”是如何实现的
  • 黑马点评大总结
  • 在Excel中轻松处理嵌套JSON数据:json-to-excel插件使用指南
  • Vue3核心语法速成
  • 慢sql处理流程和常见案例
  • 20250505下载VLC for Android
  • git上常用的12个月份对应的英语单词以及月份英语缩写形式
  • 矩阵快速幂 快速求解递推公式
  • 二重指针和二维数组
  • 力扣119题解
  • 机场围界报警系统的研究与应用
  • 深入理解 CSS Flex 布局:代码实例解析
  • WMS仓库管理系统:Java+Vue,含源码及文档,集成仓储全流程管控,实现库存精准、作业高效、数据透明
  • 俄军击落多架企图攻击莫斯科的无人机
  • 贵州黔西市游船倾覆事故发生后,多家保险公司紧急响应
  • 戴紫薇评《不像说母语者》丨后殖民语境下的母语追寻
  • 环球马术冠军赛圆满落幕,是马术盛宴更是中国马产业强大引擎
  • 当一群杜克土木工程毕业生在三四十年后怀念大学的历史课……
  • 消息人士称以色列政府初步同意扩大对加沙军事行动