pytorch中的原地与非原地操作
pytorch中的原地与非原地操作
前言
读ATD源码时,我突然对其中的一个写法产生了疑惑:
def forward(self, x, x_size, params):return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x这种写法真的可以实现残差连接吗?如果self.residual_group修改了x,那
+x是不是也不是原本的x了?
(注明一下:这个东西要问deepseek R1 671b,AI-4o不明白这个点)
解答
首先明确一个很重要的点:PyTorch中的大多数操作都是 非原地的 ,比如卷积、线性层等,它们会返回新的张量,而不是修改输入
举个例子:
def forward(self, x): x = self.conv(x) # 生成新的张量,原x不变 x = self.relu(x) # 同样生成新张量 return x
什么是 原地(inplace) 操作?
常见的原地操作包括:
- 使用
inplace=True的激活函数,如nn.ReLU(inplace=True)。- 直接对张量进行赋值操作,例如
x[:, :, ...] = ...或x += ...。- 使用原地函数,如
torch.add_、torch.mul_等。
所以,现在可以回答开头的问题了
答案:这个残差是正确的。
阅读代码可以发现,只有self.conv是inplace的,而其他的模块都是非原地的。
self.conv会修改 输入到self.conv的张量(即self.patch_unembed的输出),但 不会影响原始输入x。因为输入到self.conv的张量是self.patch_unembed(...)的结果,与原始x无关。不过,我觉得写成
def forward(self, x, x_size, params):original_x = x return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + original_x会更清晰且不容易出错。不然还得仔细检查中间的网络层到底有没有inplace操作
补充
pytorch/python的变量传参是原变量还是引用?值or址?
https://blog.csdn.net/weixin_42264234/article/details/118788863 (这个可以看看)
https://blog.csdn.net/u010167269/article/details/52073136 (这个不用看,没啥用)
我把最重要的点写在这:
def change(val):val.append(100)val = ['T', 'Z', 'Y']
nums = [0, 1]
change(nums)
print(nums)
结果是什么?为什么?
在 Python 中,函数参数的传递是按对象引用传递的(✅)。这意味着函数内部对可变对象(如列表)的修改会影响到函数外部的对象。然而,如果在函数内部重新绑定参数到一个新的对象,这个改变不会影响到函数外部的对象。
让我们逐步分析代码:
nums = [0, 1]:创建一个列表nums,其中包含两个元素[0, 1]。change(nums):调用change函数,并将nums作为参数传递进去。- 在change函数内部:
val.append(100):将100添加到val所引用的列表中。此时,val和nums仍然引用同一个列表,所以nums变为[0, 1, 100]。val = ['T', 'Z', 'Y']:将val重新绑定到一个新的列表['T', 'Z', 'Y']。这个操作不会影响到nums,因为val现在引用的是一个新的对象。print(nums):打印nums的值。由于在change函数内部对val的修改影响了nums,所以nums现在是[0, 1, 100]。因此,代码的输出是:
[0, 1, 100]总结:
nums的值被修改为[0, 1, 100],因为在change函数内部通过val.append(100)修改了nums所引用的列表。而val = ['T', 'Z', 'Y']这一行只是将val重新绑定到一个新的列表,并没有影响到nums。
