当前位置: 首页 > news >正文

减少内存占用的两种方法|torch.no_grad和disable_torch_init

方法区别

在 PyTorch 中,disable_torch_inittorch.no_grad() 是两种完全不同的机制,它们的作用和目的不同,以下是它们的区别:

1. disable_torch_init

  • 作用disable_torch_init 通常用于某些特定的框架或库中,目的是禁用 PyTorch 的默认初始化逻辑。例如,在某些情况下,框架可能希望自定义模型参数的初始化方式,而不是使用 PyTorch 默认的初始化方法。
  • 显存优化原理:禁用默认初始化可以减少初始化过程中不必要的显存分配。例如,某些框架可能会在初始化时创建额外的临时张量或执行复杂的初始化逻辑,这些操作可能会占用显存。通过禁用这些默认初始化,可以节省这部分显存。
  • 使用场景:通常用于框架内部的优化,或者在某些特定的模型加载或训练准备阶段。

2. torch.no_grad()

  • 作用torch.no_grad() 上下文管理器或装饰器,用于禁用梯度计算。在 torch.no_grad() 的上下文内,所有张量操作都不会记录梯度信息,也不会构建计算图。
  • 显存优化原理:在默认情况下,PyTorch 会为每个需要梯度的张量(requires_grad=True)保存中间结果,以便在反向传播时计算梯度。这些中间结果会占用显存。通过禁用梯度计算,torch.no_grad() 可以避免这些中间结果的存储,从而显著减少显存占用。
  • 使用场景:主要用于模型的推理(inference)阶段,或者在不需要计算梯度的场景中。例如,在模型评估、数据预处理、特征提取等场景中,torch.no_grad() 是常用的优化手段。

3. 具体区别

特性disable_torch_inittorch.no_grad()
作用范围禁用模型参数的初始化逻辑禁用梯度计算和计算图构建
显存优化原理减少初始化过程中不必要的显存分配避免存储中间梯度和计算图,减少显存占用
使用场景模型加载或训练准备阶段模型推理、评估、数据预处理等
是否影响模型结构可能影响模型参数的初始化方式不影响模型结构,仅影响梯度计算
是否需要手动启用需要框架或用户显式调用可通过上下文管理器或装饰器显式启用

4. 总结

  • disable_torch_init 是一种针对模型初始化过程的优化机制,主要用于减少初始化阶段的显存占用。
  • torch.no_grad() 是一种禁用梯度计算的工具,主要用于推理阶段,通过避免计算图的构建和梯度存储来减少显存占用。

两者虽然都可以减少显存占用,但作用机制和使用场景完全不同。在实际应用中,torch.no_grad() 是更常用且更通用的显存优化手段,而 disable_torch_init 更多是框架内部的优化策略。

(常见)在评估前@torch.no_grad()

源代码:

class no_grad(_DecoratorContextManager):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parenthesis.)

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

(放在评估函数里面)disable_torch_init()

源代码:

def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

相关文章:

  • python-leetcode-最长公共子序列
  • 从二维随机变量到多维随机变量
  • P9420 [蓝桥杯 2023 国 B] 双子数--最高效的质数筛【埃拉托斯特尼筛法】
  • 【uniapp】在UniApp中实现持久化存储:安卓--导出数据为jsontxt
  • 【全干货】cocos简短demo制作-三消类游戏
  • 测试的BUG分析
  • 第二十九:5.7.【$subscribe】侦听数据
  • SpringBoot集成easy-captcha图片验证码框架
  • 《Somewhat Practical Fully Homomorphic Encryption》笔记 (BFV 源于这篇文章)
  • 前端Javascrip后端Net6前后分离文件上传案例(完整源代码)下载
  • 2025 最新版鸿蒙 HarmonyOS 开发工具安装使用指南
  • Go入门之文件
  • 华为AP 4050DN-HD的FIT AP模式改为FAT AP,家用FAT基本配置
  • 练习题:57
  • JDBC 进阶(未完结)
  • C# 确保程序只有一个实例运行
  • 如何确保邮件内容符合不同地区用户的文化习惯
  • 原子性(Atomicity)和一致性(Consistency)的区别?
  • 【备份】php项目处理跨域请求踩坑
  • 【JavaSE-2】数据类型与变量
  • 陈刚:推动良好政治生态和美好自然生态共生共优相得益彰
  • 遭车祸罹难的村医遇“身份”难题:镇卫生院否认劳动关系,家属上诉后二审将开庭
  • 政企共同发力:多地密集部署外贸企业抢抓90天政策窗口期
  • 网易一季度净利增长三成,丁磊:高度重视海外游戏市场
  • 获派驻6年后,中国驻厄瓜多尔大使陈国友即将离任
  • 齐白石精品在波士顿展出,“白石画屋”呈现水墨挥洒