pytorch部分函数理解
维度转换函数
rearrange函数
import torch
from einops import rearrangedata = torch.range(1, 25)
print(data)
data1 = rearrange(data, '(a b) -> a b', a=5, b=5)
data2 = rearrange(data, '(b a) -> a b', a=5, b=5)
print(data1)
print(data2)
对于data1,可以理解为按行展开;对于data2,则理解为按列展开。对data1做一个转置即可得到data2。
torch.view()和torch.reshape()函数
torch.view函数
data = torch.range(1, 12)
data1=data.view(2,6)
data2=data.view(3,4)
print(data)
print(data1)
print(data2)
torch.reshape函数
data = torch.range(1, 12)
data1=data.reshape(2,6)
data2=data.reshape(3,4)
print(data)
print(data1)
print(data2)
两者输出都是一样的
这些常见的维度转换函数默认都是按照**行展开
**的。
unfold函数和fold函数
unfold函数的输入数据是四维,但输出是三维的。假设输入数据是[B, C, H, W], 那么输出数据是 [B, C* kH * kW, L], 其中 K H K_H KH是核的高, K W K_W KW是核宽。
L则是这个高kH宽kW的核,能在H*W区域按照指定stride滑动的次数。
L = ( H − K H + 1 ) ∗ ( W − K W + 1 ) L=(H-K_H+1)*(W-K_W+1) L=(H−KH+1)∗(W−KW+1)