PyTorch中“原地”赋值的思考
在开发一个PyTorch模块时,遇到了一个诡异的现象,将他描述出来就是下面这样:
f[..., :p_index - 1] = f[..., 1:p_index]
这个操作将f张量的部分数值进行左移,我在模型训练的时候还能正常跑,但是当我将模型部署到项目中时,这行代码报错了!
Traceback (most recent call last):File "<input>", line 1, in <module>
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
这个PyTorch报错是因为在执行操作时,输入张量和目标张量共享了同一块内存地址(存在内存重叠),导致PyTorch无法安全地完成原地(in-place)操作。
既然这样的话为什么在模型训练的时候不会这样呢?后面我仔细研究了一下午,发现了下面的原因:
当我们模型在训练阶段中,f的形状通常是(B,F)的形式存在的,而在部署的时候,作推理时数据通常是(1,F)的形式,所以会出现下面的情况:
# 创建高维张量(3维)
f_3d = torch.randn(16, 1, 25)
slice_3d = f_3d[..., 1:24] # 源切片print("高维张量切片是否连续:")
print(slice_3d.is_contiguous()) # 输出 False# 创建一维张量对比
f_1d = torch.randn(1, 1, 25)
slice_1d = f_1d[..., 1:24]print("\n一维张量切片是否连续:")
print(slice_1d.is_contiguous()) # 输出 True
可以看到,当张量是维度大于1时,其在内存中是非连续存储的,而张量维度为1时,其在内存中是连续存储的。对于非连续张量,PyTorch会在赋值时隐式创建临时副本,避免内存覆盖。因此在进行原地赋值时不会报错。
最后,为了加强代码的鲁棒性,我在所有涉及这部分操作的代码后面加上了clone()函数。
f[..., :p_index - 1] = f[..., 1:p_index].clone()