pytorch中的原地与非原地操作
pytorch中的原地与非原地操作
前言
读ATD源码时,我突然对其中的一个写法产生了疑惑:
def forward(self, x, x_size, params):return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
这种写法真的可以实现残差连接吗?如果self.residual_group修改了x,那
+x
是不是也不是原本的x了?
(注明一下:这个东西要问deepseek R1 671b,AI-4o不明白这个点)
解答
首先明确一个很重要的点:PyTorch中的大多数操作都是 非原地的 ,比如卷积、线性层等,它们会返回新的张量,而不是修改输入
举个例子:
def forward(self, x): x = self.conv(x) # 生成新的张量,原x不变 x = self.relu(x) # 同样生成新张量 return x
什么是 原地(inplace) 操作?
常见的原地操作包括:
- 使用
inplace=True
的激活函数,如nn.ReLU(inplace=True)
。- 直接对张量进行赋值操作,例如
x[:, :, ...] = ...
或x += ...
。- 使用原地函数,如
torch.add_
、torch.mul_
等。
所以,现在可以回答开头的问题了
答案:这个残差是正确的。
阅读代码可以发现,只有self.conv是inplace的,而其他的模块都是非原地的。
self.conv
会修改 输入到self.conv
的张量(即self.patch_unembed
的输出),但 不会影响原始输入x
。因为输入到self.conv
的张量是self.patch_unembed(...)
的结果,与原始x
无关。不过,我觉得写成
def forward(self, x, x_size, params):original_x = x return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + original_x
会更清晰且不容易出错。不然还得仔细检查中间的网络层到底有没有inplace操作
补充
pytorch/python的变量传参是原变量还是引用?值or址?
https://blog.csdn.net/weixin_42264234/article/details/118788863 (这个可以看看)
https://blog.csdn.net/u010167269/article/details/52073136 (这个不用看,没啥用)
我把最重要的点写在这:
def change(val):val.append(100)val = ['T', 'Z', 'Y']
nums = [0, 1]
change(nums)
print(nums)
结果是什么?为什么?
在 Python 中,函数参数的传递是按对象引用传递的(✅)。这意味着函数内部对可变对象(如列表)的修改会影响到函数外部的对象。然而,如果在函数内部重新绑定参数到一个新的对象,这个改变不会影响到函数外部的对象。
让我们逐步分析代码:
nums = [0, 1]
:创建一个列表nums
,其中包含两个元素[0, 1]
。change(nums)
:调用change
函数,并将nums
作为参数传递进去。- 在change函数内部:
val.append(100)
:将100
添加到val
所引用的列表中。此时,val
和nums
仍然引用同一个列表,所以nums
变为[0, 1, 100]
。val = ['T', 'Z', 'Y']
:将val
重新绑定到一个新的列表['T', 'Z', 'Y']
。这个操作不会影响到nums
,因为val
现在引用的是一个新的对象。print(nums)
:打印nums
的值。由于在change
函数内部对val
的修改影响了nums
,所以nums
现在是[0, 1, 100]
。因此,代码的输出是:
[0, 1, 100]
总结:
nums
的值被修改为[0, 1, 100]
,因为在change
函数内部通过val.append(100)
修改了nums
所引用的列表。而val = ['T', 'Z', 'Y']
这一行只是将val
重新绑定到一个新的列表,并没有影响到nums
。