当前位置: 首页 > news >正文

释放内存与加速推理:PyTorch的torch.no_grad()与torch.inference_mode()

文章目录

    • 0. 前言
    • 1. 为什么需要它们?理解计算图与梯度
    • 2. `torch.no_grad()`:经典解决方案
    • 3. `torch.inference_mode()`:更高效的继任者
    • 4. 关键区别与最佳实践
    • 5. 总结

0. 前言

📣按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在PyTorch模型中,从训练切换到评估/推理时,我们经常会看到model.eval()的身影。然而,还有一个(或者说两个)更为重要的"开关"能够显著提升推理性能并减少内存占用——它们就是torch.no_grad()和它的进化版torch.inference_mode(),本文将介绍它们的用法。

1. 为什么需要它们?理解计算图与梯度

我在前文 基于TorchViz详解计算图(附代码) 详细介绍过计算图。

简单来说,PyTorch的关键特性是自动求导。在训练过程中,每当对张量进行计算时,PyTorch会默默地构建一个计算图,跟踪所有操作以便通过反向传播计算梯度。

在训练时,这种跟踪是必要的。但在推理时,我们只需要前向传播的输出,不需要计算梯度。继续维护计算图只会:

  1. 消耗额外内存存储中间结果的梯度信息
  2. 增加计算开销为不必要的反向传播做准备

2. torch.no_grad():经典解决方案

下面我们直接通过实例来演示torch.no_grad()的作用,首先先定义一个极简的模型:

import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Sequential(nn.Linear(5000,50000),nn.ReLU(),nn.Linear(50000,5000))def forward(self,x):return self.linear(x)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleModel().to(device)x = torch.ones(1000,5000).to(device)
x.requires_grad =True

然后对比不使用torch.no_grad()和使用torch.no_grad()

print("====有梯度的计算====")
torch.cuda.empty_cache()
start_mem = torch.cuda.memory_allocated()
output_with_grad = model(x)
mem_with_grad = (torch.cuda.memory_allocated() - start_mem) / 1024 ** 2
print(f"是否有梯度:{output_with_grad.requires_grad}")
print(f"输出梯度函数{output_with_grad.grad_fn}")
print(f"有梯度的内存占用{mem_with_grad:.2f}MB")print("====使用torch.no_grad()====")
torch.cuda.empty_cache()
start_mem = torch.cuda.memory_allocated()
with torch.no_grad():output_with_no_grad = model(x)mem_with_no_grad = (torch.cuda.memory_allocated() - start_mem) / 1024 ** 2print(f"是否有梯度:{output_with_no_grad.requires_grad}")print(f"输出梯度函数{output_with_no_grad.grad_fn}")print(f"无梯度的内存占用{mem_with_no_grad:.2f}MB")print(f"使用no_grad()能节省{(1-mem_with_no_grad/mem_with_grad)*100:.2f}%的内存")

输出结果:

在这里插入图片描述

可以看到,在torch.no_grad()可以节省大量的内存!

3. torch.inference_mode():更高效的继任者

PyTorch 1.10引入了torch.inference_mode(),它比torch.no_grad()更加激进和高效,我们再看下torch.inference_mode()的内存占用情况

print("====使用torch.inference_mode====")
model.eval()
torch.cuda.empty_cache()
start_mem = torch.cuda.memory_allocated()
with torch.inference_mode():output_inference_mode = model(x)mem_inference_mode = (torch.cuda.memory_allocated() - start_mem) / 1024 ** 2print(f"是否有梯度:{output_inference_mode.requires_grad}")print(f"输出梯度函数{output_inference_mode.grad_fn}")print(f"inference_mode的内存占用{mem_inference_mode:.2f}MB")

输出结果:

在这里插入图片描述
在本次实验中,torch.inference_mode()torch.no_grad()显示出相同的内存占用(均为19.07MB),这主要是因为两者在核心优化机制上是一致的:它们都通过完全禁用梯度计算和计算图构建来实现主要的内存节省。在简单的单次前向传播场景下,内存消耗的主要来源是计算图中间结果的存储,而这一点两者都已完美解决。

然而,内存节省只是性能优化的一个维度torch.inference_mode()作为torch.no_grad()的进化版本,其真正优势在于更激进的内部优化策略——包括禁用版本计数器、减少运行时检查等,这些优化虽然对单次内存占用影响不大,但对计算效率的提升却至关重要。为了全面评估两者的性能差异,下面我们通过时间效率测试来揭示torch.inference_mode()在推理速度上的显著优势:

print("====让我们再对比下时间====")
import time
model.eval()# 测试 torch.no_grad() 性能
start_time = time.time()
with torch.no_grad():for _ in range(100):_ = model(x)
no_grad_time = time.time() - start_time# 测试 torch.inference_mode() 性能
start_time = time.time()
with torch.inference_mode():for _ in range(100):_ = model(x)
inference_time = time.time() - start_timeprint(f"torch.no_grad() 时间: {no_grad_time:.4f}s")
print(f"torch.inference_mode() 时间: {inference_time:.4f}s")
print(f"inference_mode 比 no_grad 快: {(1 - inference_time/no_grad_time)*100:.1f}%")

输出结果:
在这里插入图片描述

4. 关键区别与最佳实践

特性torch.no_grad()@torch.inference_mode()
梯度计算禁用禁用
计算图构建仍然构建,但不记录操作完全不构建
版本计数器仍然递增不递增
性能较好更优
内存使用较少更少
灵活性可在其中启用梯度不能在其中启用梯度

最佳实践建议:

  1. 训练代码中:使用model.eval() + torch.no_grad()
  2. 部署/生产环境中优先使用@torch.inference_mode()
  3. 需要调试或特殊情况:使用torch.no_grad()(更灵活)

5. 总结

torch.no_grad()torch.inference_mode()都是PyTorch推理优化的重要工具。理解它们的区别并正确使用,可以:

  • 显著减少内存占用
  • 提升推理速度
  • 让模型部署更加高效

记住这个简单的规则:在不需要梯度计算的任何地方,特别是模型推理时,都应该使用torch.inference_mode()

http://www.dtcms.com/a/524085.html

相关文章:

  • 论文笔记(九十六)VGGT: Visual Geometry Grounded Transformer
  • 城市基础设施安全运行监管平台
  • 网络 UDP 和 TCP / IP详细介绍
  • 数据结构(8)
  • [cpprestsdk] ~异步流处理(eg`basic_istream`、`basic_ostream`、`streambuf`) 底层
  • Linux 查找符合条件的文档
  • ​九小场所 / 乡镇监督防火 ——1 个平台管水源 / 隐患,整改率提 80%
  • 郑州做网站找绝唯科技地方类门户网站
  • 哪里可以做免费的物流网站国外室内设计案例网站
  • 【Linux系统】从零掌握make与Makefile:高效自动化构建项目的工具
  • ML:Supervised/Unsupervised
  • 开发网站多少钱北京 工业网站建设公司排名
  • 【后端开发面试题】
  • 【coze】基础概念与使用
  • Java 语法糖详解(含底层原理)
  • 企业网站介绍越南做企业网站
  • 免费建设电影网站宁波优化推广找哪家
  • JAVA1024 类 object类 包装类 享元模式 ;类继承 :interface ;构造方法
  • 树与二叉树的奥秘全解析
  • 《Python 正则表达式完全指南:从入门到精通》(AI版)
  • 【linux】vim快速清空整个文件
  • 基于单片机的故障检测自动保护智能防夹自动门设计及LCD状态显示系统
  • 2025妈妈杯大数据竞赛B题mathorcup:物流理赔风险识别及服务升级数学建模数模教学大学生辅导思路代码助攻
  • 对监控理解
  • 体育数据传输:HTTP API与WebSocket的核心差异
  • 货代如何做亚马逊和速卖通网站dedecms三合一网站源码
  • 燃烧学课程网站建设业之峰装饰官网
  • 做料理网站关键词怎么设置上海专业的网站建设
  • 英文 PDF 文档翻译成中文的优质应用
  • css实现拼图,响应不同屏幕宽度