当前位置: 首页 > news >正文

pytorch基本运算-分离计算

【1】引言

前序学习进程中,已经对pytor求导有了基本认识,知道requires_grad_(True)和backward()是求导必备的声明。
但有一种特殊情况,如果变量z=f(xy)z=f(xy)z=f(xy),但同时y=f(x)y=f(x)y=f(x),也就是变量yyy是变量xxx的函数,变量zzz同时是变量yyy和变量xxx的函数,此时可以有表达式:
z=f(xy)=f(x(f(x)))z=f(xy)=f(x(f(x)))z=f(xy)=f(x(f(x)))如果我们就是想获得yyy计算出来后,zzz关于xxx的导数,,此时就要考虑如何将yyy单独分流出来,不要让梯度经过yyy追溯到xxx
这就是本次要学习的重点:分离计算。

【2】detach()函数

分离计算要使用detach()函数,用一个自定义的新变量名比如ttt来获取原变量yyy的值,此时把zzz的表达式改写成:
z=f(xy)=f(xt))z=f(xy)=f(xt))z=f(xy)=f(xt))
此时计算dzdc\frac{dz}{dc}dcdz就不会通过ttt传递到xxxttt只是一个和yyy相等的常数。
为此找一个例子做测试。

# 引入模块
import torch
from torch.autograd import backward# 定义初始张量
x=torch.arange(5.0)
# 声明要对x求导
z=x.requires_grad_(True)
print('z=',z)
# 乘积定义
y=x*x
z=y*x
# 梯度计算
z.sum().backward()
print('z=',z)
# 此时没有单独提取y
g1=x.grad
print('g1=',g1)
# 梯度清零
x.grad.zero_()
# 使用t分离y
t=y.detach()
# 重新定义函数
z=t*x
# 计算梯度
z.sum().backward()
print('z=',z)
# g2是用t分离y后获得的梯度
g2=x.grad
print('g2=',g2)
# 理论上,根据z=t*x,如果t是一个常数,梯度结果就是t
print('t=',t)

这个代码块先计算了不分离yyy的梯度g1g1g1,然后计算了分离yyy的梯度g2g2g2,证明了分离后确实梯度计算不会再由ttt追溯到xxx,实现了保持yyy为常数的运算目标。
计算结果为:

在这里插入图片描述

【3】总结

学习了pytorch分离计算导数的基本概念。


文章转载自:

http://tDx03oxr.wLfxn.cn
http://Nur1EOD6.wLfxn.cn
http://CMTon8wi.wLfxn.cn
http://A392ghNe.wLfxn.cn
http://QUbTxHLO.wLfxn.cn
http://qmB3evLe.wLfxn.cn
http://cI69ZKD3.wLfxn.cn
http://OPChuCBf.wLfxn.cn
http://uTNc3m7g.wLfxn.cn
http://5eiHA0Ia.wLfxn.cn
http://37nATbQm.wLfxn.cn
http://x4D70u53.wLfxn.cn
http://r2z3JJHd.wLfxn.cn
http://wdp4BcdJ.wLfxn.cn
http://xf58bvNo.wLfxn.cn
http://MfqSzxbO.wLfxn.cn
http://FDqWklXB.wLfxn.cn
http://rJTlGh6A.wLfxn.cn
http://NSy1upEl.wLfxn.cn
http://cPrKRuUh.wLfxn.cn
http://r5NbtW0B.wLfxn.cn
http://MtSoHK4c.wLfxn.cn
http://inHuwDGG.wLfxn.cn
http://VPrSUcSh.wLfxn.cn
http://I830AjRJ.wLfxn.cn
http://NxtHdRKF.wLfxn.cn
http://AKaS9N7v.wLfxn.cn
http://jZPyCha0.wLfxn.cn
http://rlc443F9.wLfxn.cn
http://IFbKXrlc.wLfxn.cn
http://www.dtcms.com/a/374781.html

相关文章:

  • 基于容器化云原生的 MySQL 及中间件高可用自动化集群项目
  • “图观”端渲染场景编辑器
  • 构建分布式京东商品数据采集系统:基于 API 的微服务实现方案
  • HTML5点击转圈圈 | 深入了解HTML5技术中的动态效果与用户交互设计
  • springboot rabbitmq 延时队列消息确认收货订单已完成
  • CString(MFC/ATL 框架)和 QString(Qt 框架)
  • Sklearn(机器学习)实战:鸢尾花数据集处理技巧
  • 工具框架:Scikit-learn、Pandas、NumPy预测鸢尾花的种类
  • AI GEO 优化能否快速提升网站在搜索引擎的排名?​
  • nvm和nrm的详细安装配置,从卸载nodejs到安装NVM管理nodejs版本,以及安装nrm管理npm版本
  • 对口型视频怎么制作?从脚本到成片的全流程解析
  • 从“能说话”到“会做事”:AI Agent如何重构日常工作流?
  • 洛谷 P1249 最大乘积-普及/提高-
  • 小红书获取笔记详情API接口会返回哪些数据?
  • JAVA Spring Boot maven导入使用本地SDK(jar包)
  • Linux/UNIX系统编程手册笔记:SOCKET
  • F5和Nginx的区别
  • 9.9网编简单TCP,UDP的实现day2
  • Day39 SQLite数据库操作与HTML核心API及页面构建
  • Vue3 与 AntV X6 节点传参、自动布局及边颜色控制教程
  • 线程与进程的区别
  • RAC概念笔记
  • 如何将视频从安卓手机传输到电脑?
  • Day04_苍穹外卖——套餐管理(实战)
  • ElementUI 组件概览
  • fifo之读写指针
  • 【第三次全国土壤普查】一键制备土壤三普环境变量23项遥感植被指数神器
  • Java反射机制详解
  • PDF文件中的广告二维码图片该怎么批量删除
  • 记一次 .NET 某中医药附属医院门诊系统 崩溃分析