当前位置: 首页 > news >正文

PyTorch中“原地”赋值的思考

在开发一个PyTorch模块时,遇到了一个诡异的现象,将他描述出来就是下面这样:

f[..., :p_index - 1] = f[..., 1:p_index]

这个操作将f张量的部分数值进行左移,我在模型训练的时候还能正常跑,但是当我将模型部署到项目中时,这行代码报错了!

Traceback (most recent call last):File "<input>", line 1, in <module>
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

这个PyTorch报错是因为在执行操作时,输入张量和目标张量共享了同一块内存地址(存在内存重叠),导致PyTorch无法安全地完成原地(in-place)操作。

既然这样的话为什么在模型训练的时候不会这样呢?后面我仔细研究了一下午,发现了下面的原因:


当我们模型在训练阶段中,f的形状通常是(B,F)的形式存在的,而在部署的时候,作推理时数据通常是(1,F)的形式,所以会出现下面的情况:

# 创建高维张量(3维)
f_3d = torch.randn(16, 1, 25)
slice_3d = f_3d[..., 1:24]  # 源切片print("高维张量切片是否连续:")
print(slice_3d.is_contiguous())  # 输出 False# 创建一维张量对比
f_1d = torch.randn(1, 1, 25)
slice_1d = f_1d[..., 1:24]print("\n一维张量切片是否连续:")
print(slice_1d.is_contiguous())  # 输出 True

可以看到,当张量是维度大于1时,其在内存中是非连续存储的,而张量维度为1时,其在内存中是连续存储的。对于非连续张量,PyTorch会在赋值时隐式创建临时副本,避免内存覆盖。因此在进行原地赋值时不会报错。

最后,为了加强代码的鲁棒性,我在所有涉及这部分操作的代码后面加上了clone()函数。

f[..., :p_index - 1] = f[..., 1:p_index].clone()

相关文章:

  • GPU虚拟化实现(六)
  • 线段树原理和代码详解
  • 课题推荐——通信信号处理中的非线性系统状态估计(如信号跟踪、相位恢复等场景),使用无迹卡尔曼滤波(UKF)的非线性滤波算法,MATLAB实现
  • 【C++重载操作符与转换】输入和输出操作符
  • 深入解析Session与Cookie:从HTTP无状态到现代会话管理
  • 【kafka系列】消费者组
  • 使用Nexus搭建远程maven仓库
  • MySQL零基础入门:Ubuntu环境安装与操作精解
  • AWK 文本分析工具核心总结
  • HashMap,高效 哈希
  • Python生活手册-文件二进制:从快递柜到生鲜冷链的数据保鲜术
  • 业务流程BPM能力框架体系及华为中兴流程变革案例P83(83页PPT)(文末有下载方式)
  • python拜占庭将军
  • 【大模型实战篇】华为信创环境采用vllm部署QwQ-32B模型
  • 部署.NET6.0 Web API项目到Docker
  • 基于开源AI智能名片链动2+1模式S2B2C商城小程序的电商直播流量转化路径研究
  • 【Linux】Makefile
  • AI大模型基础设施:主流的几款开源AI大语言模型的本地部署成本
  • kafka学习笔记(四、生产者(客户端)深入研究(二)——消费者协调器与_consumer_offsets剖析)
  • windows系统搭建自己的ftp服务器,保姆级教程(用户验证+无验证)
  • 泽连斯基:美乌矿产协议将提交乌拉达批准
  • 国际著名学者Charles M. Lieber全职受聘清华深圳国际研究生院
  • 龚惠民已任江西省司法厅党组书记
  • 新片|《我仍在此》定档5月,《新·驯龙高手》同步北美上映
  • 原国家有色金属工业局副局长黄春萼逝世,享年86岁
  • 节前A股持续震荡,“五一”假期持股还是持币过节胜率更高?