当前位置: 首页 > news >正文

pytorch retain_grad vs requires_grad

requires_grad大家都挺熟悉的,因此穿插在retain_grad的例子里进行捎带讲解就行。下面看一个代码片段:

import torch

# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)

# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3

# 继续计算,得到 z
z = y * 4

# 反向传播
z.backward()

# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

x.grad: tensor(12.)
y.grad: None
/tmp/ipykernel_219007/1060175670.py:17: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)
  print("y.grad:", y.grad)

警告的大致意思是:访问了非叶子节点的.grad属性,但非叶子节点的.grad属性并不会在反向传播的过程中被自动保存下来(这是为了节省内存,毕竟我们只需要计算那些手动设置.requires_gradTrue的张量的梯度,并进行梯度更新,对吧?)

因此,我们只需要添加一行代码y.retain_grad(),修改后的代码如下:

import torch

# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)

# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3
y.retain_grad()

# 继续计算,得到 z
z = y * 4

# 反向传播
z.backward()

# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

x.grad: tensor(12.)
y.grad: tensor(4.)

可以看到,现在非叶子节点y的梯度也在反向传播以后被正确保存了!

http://www.dtcms.com/a/63759.html

相关文章:

  • Python 融于ASP框架
  • snmp开发
  • C++内存模型和原子操作_第五章_《C++并发编程实战》笔记
  • java之uniapp实现门店地图
  • 前端 - vue - - import引入报错 require引入不报错 package.json中type的用法 延迟导入资源
  • xsync集群分发脚本开发指南
  • 使用AI一步一步实现若依前端(9)
  • 游戏引擎学习第150天
  • 洗鞋小程序(源码+文档+讲解+演示)
  • Spring(4)——响应相关
  • 如何测试 item_get_video 小红书接口返回数据的详细说明
  • 【统计至简】【古典概率模型】联合概率、边缘概率、条件概率、全概率
  • 【实战ES】实战 Elasticsearch:快速上手与深度实践-5.4.2用户画像聚合(Terms Aggregation + Cardinality)
  • SpringCloud——环境搭建
  • html css网页制作成品——糖果屋网页设计(4页)附源码
  • Java中数据库索引选择B+树而非红黑树的详细解析
  • 【前端拓展】Canvas性能革命!WebGPU + WebAssembly混合渲染方案深度解析
  • 【MySQL】增删改查进阶
  • 学习C2CRS Ⅲ (Response Generation Module)
  • 【编程向导】-JavaScript-基础语法-类型检测
  • 软考高级信息系统项目管理师笔记-第23章组织通用管理
  • redis趣味解读
  • SpringMVC工作原理
  • Python :Pandas
  • harmonyOS(鸿蒙)— 网络权限(解决app网络资源无法加载,图片无法显示)
  • 帕金森病如何 “偷走” 患者的正常生活?
  • gin框架
  • ORACLE EBS数据库RELINK方式搭建克隆环境
  • 黑色RGB是什么
  • C#实现AES-CBC加密工具类(含完整源码及使用教程)