PyTorch --torch.cat张量拼接原理
在 PyTorch 的 torch.cat 函数中,out 参数用于指定输出张量的存储位置。是否使用 out 参数直接影响结果的存储方式和张量的内存行为。以下是详细解释:
- 不使用 out 参数(默认行为)
含义:不提供 out 参数时,torch.cat 会创建一个新的张量来存储拼接后的结果,并返回这个新张量。
特点:
内存分配:PyTorch 会为结果分配新的内存空间。
原张量不变:输入的原始张量(如 tensors 中的张量)不会被修改。
返回新张量:返回的张量是独立的,与输入张量没有内存共享 - 使用 out 参数
含义:通过 out 参数提供一个已存在的张量,torch.cat 将直接将结果写入该张量中,无需创建新张量。
特点:
内存复用:避免分配新内存,直接利用已有张量的内存空间。
原张量被修改:out 指定的张量会被覆盖,其内容会被替换为拼接结果。
形状匹配:out 张量的形状必须与拼接后的结果完全一致,否则会报错。
以下是关于 torch.cat
在不同 dim
下的拼接过程的公式化描述及可视化示例,用分块矩阵的形式呈现:
1. 数学公式化描述
1.1 沿 dim=0
(行方向)拼接
假设:
- 张量 A 的形状为 m × n m\times n m×n
- 张量 B 的形状为 p × n p\times n p×n
- 所有张量在非拼接维度(列数 n n n)必须一致。
拼接后的张量 C 形状为
(
m
+
p
)
×
n
(m + p) \times n
(m+p)×n,公式表示为:
C
=
[
A
B
]
C = \begin{bmatrix}A \\B\end{bmatrix}
C=[AB]
即:
C
i
,
j
=
{
A
i
,
j
,
若
1
≤
i
≤
m
,
B
i
−
m
,
j
,
若
m
+
1
≤
i
≤
m
+
p
.
C_{i,j} = \begin{cases} A_{i,j}, & \text{若 } 1 \leq i \leq m, \\ B_{i-m,j}, & \text{若 } m+1 \leq i \leq m+p. \end{cases}
Ci,j={Ai,j,Bi−m,j,若 1≤i≤m,若 m+1≤i≤m+p.
1.2 沿 dim=1
(水平/列方向)拼接
假设条件:
- 张量 A A A 的形状为 m × n m \times n m×n
- 张量 B B B 的形状为 m × p m \times p m×p
- 两个张量在非拼接维度(行维度 m m m)上必须保持一致
拼接操作:
水平拼接后的张量
C
C
C 形状为
m
×
(
n
+
p
)
m \times (n + p)
m×(n+p),其数学表示为:
C = [ A B ] C = \begin{bmatrix} A & B \end{bmatrix} C=[AB]
元素级定义:
C
i
,
j
=
{
A
i
,
j
当
1
≤
j
≤
n
B
i
,
j
−
n
当
n
+
1
≤
j
≤
n
+
p
C_{i,j} = \begin{cases} A_{i,j} & \text{当 } 1 \leq j \leq n \\ B_{i,j-n} & \text{当 } n+1 \leq j \leq n+p \end{cases}
Ci,j={Ai,jBi,j−n当 1≤j≤n当 n+1≤j≤n+p
维度说明:
- 行维度: m m m(保持不变)
- 列维度: n + p n + p n+p( A A A 和 B B B 列数的总和)
2. 具体示例(使用数值矩阵)
张量拼接示例
示例 1:沿第 0 维度拼接 (dim=0
)
输入张量:
- A = [ 1 2 3 4 ] ( 形状 2 × 2 ) A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \quad (\text{形状 } 2 \times 2) A=[1324](形状 2×2)
- B = [ 5 6 7 8 ] ( 形状 2 × 2 ) B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形状 } 2 \times 2) B=[5768](形状 2×2)
拼接操作:
C
=
concat
(
A
,
B
,
dim
=
0
)
C = \operatorname{concat}(A, B, \text{dim}=0)
C=concat(A,B,dim=0)
输出结果:
C
=
[
1
2
3
4
5
6
7
8
]
(
形状
4
×
2
)
C = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \hline 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形状 } 4 \times 2)
C=
13572468
(形状 4×2)
示例 2:沿第 1 维度拼接 (dim=1
)
输入张量:
- A = [ 1 2 3 4 ] ( 形状 2 × 2 ) A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \quad (\text{形状 } 2 \times 2) A=[1324](形状 2×2)
- B = [ 5 6 7 8 ] ( 形状 2 × 2 ) B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形状 } 2 \times 2) B=[5768](形状 2×2)
拼接操作:
C
=
concat
(
A
,
B
,
dim
=
1
)
C = \operatorname{concat}(A, B, \text{dim}=1)
C=concat(A,B,dim=1)
输出结果:
C
=
[
1
2
5
6
3
4
7
8
]
(
形状
2
×
4
)
C = \begin{bmatrix} 1 & 2 & 5 & 6 \\ 3 & 4 & 7 & 8 \end{bmatrix} \quad (\text{形状 } 2 \times 4)
C=[13245768](形状 2×4)
关键说明
dim=0
表示垂直拼接(沿行方向堆叠)dim=1
表示水平拼接(沿列方向连接)- 拼接维度的大小可以不同,但其他维度必须完全相同(例如
dim=1
拼接时,两个张量的行数必须相等)
以下是用表格形式展示 dim=0
和 dim=1
的拼接结果:
沿 dim=0
拼接
初始张量 ( A ) | 初始张量 ( B ) | 拼接结果 ( C )(dim=0) |
---|---|---|
[[1, 2], | [[5, 6], | [[1, 2], |
[3, 4]] | [7, 8]] | [3, 4], |
形状:2×2 | 形状:2×2 | [5, 6], |
[7, 8]] | ||
形状:4×2 |
沿 dim=1
拼接
初始张量 ( A ) | 初始张量 ( B ) | 拼接结果 ( C )(dim=1) |
---|---|---|
[[1, 2], | [[5, 6], | [[1, 2, 5, 6], |
[3, 4]] | [7, 8]] | [3, 4, 7, 8]] |
形状:2×2 | 形状:2×2 | 形状:2×4 |
3. 关键点总结
-
维度一致性:
dim=0
:所有张量的列数(n
)必须相同。dim=1
:所有张量的行数(m
)必须相同。
-
拼接方向:
dim=0
:垂直方向拼接(行数相加)。dim=1
:水平方向拼接(列数相加)。
-
数学符号表示:
dim=0
: C = [ A B ] C = \begin{bmatrix} A \\ B \end{bmatrix} C=[AB]dim=1
: C = [ A B ] C = \begin{bmatrix} A & B \end{bmatrix} C=[AB]
5. 扩展示例(多维张量)
假设张量为三维(如图像的批处理):
A
∈
R
B
×
C
×
H
×
W
A \in \mathbb{R}^{B \times C \times H \times W}
A∈RB×C×H×W
B
∈
R
B
′
×
C
×
H
×
W
B \in \mathbb{R}^{B' \times C \times H \times W}
B∈RB′×C×H×W
- 拼接
dim=0
(批处理方向):
C ∈ R ( B + B ′ ) × C × H × W C \in \mathbb{R}^{(B+B') \times C \times H \times W} C∈R(B+B′)×C×H×W
以下是三维张量拼接的示例,使用分块矩阵的形式展示沿不同维度(dim=0
, dim=1
, dim=2
)的拼接过程:
三维张量拼接示例
假设两个三维张量 ( A ) 和 ( B ),形状分别为:
-
A
∈
R
2
×
3
×
2
(形状:
2
×
3
×
2
)
A \in \mathbb{R}^{2 \times 3 \times 2}(形状:2×3×2)
A∈R2×3×2(形状:2×3×2)
-
B
∈
R
1
×
3
×
2
(形状:
1
×
3
×
2
)
B \in \mathbb{R}^{1 \times 3 \times 2} (形状:1×3×2)
B∈R1×3×2(形状:1×3×2)
1. 沿 dim=0
拼接(扩展第一个维度)
- 拼接条件:除
dim=0
外,其他维度(3×2)必须一致。 - 拼接结果形状: ( 2 + 1 ) × 3 × 2 = 3 × 3 × 2 (2 + 1) \times 3 \times 2 = 3 \times 3 \times 2 (2+1)×3×2=3×3×2
- 数学表示:
C = [ A 1 A 2 B 1 ] C = \begin{bmatrix} A_{1} \\ A_{2} \\ B_{1} \end{bmatrix} C= A1A2B1
其中 A 1 , A 2 A_1, A_2 A1,A2 是 A A A的两个“块”, B 1 B_1 B1是 B B B的唯一块。
数值示例:
-
输入张量:
- A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= 135246 , 791181012
- B = [ [ 13 14 15 16 17 18 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 17 & 18 \end{bmatrix} \end{bmatrix} B= 131517141618
-
拼接结果(
dim=0
):
C = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] , [ 13 14 15 16 17 18 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix}, & \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 17 & 18 \end{bmatrix} \end{bmatrix} C= 135246 , 791181012 , 131517141618
形状: 3 × 3 × 2 3 \times 3 \times 2 3×3×2
2. 沿 dim=1
拼接(扩展第二个维度)
- 拼接条件:除
dim=1
外,其他维度(2×2)必须一致。 - 假设调整后的张量形状:
-
A
∈
R
2
×
3
×
2
A \in \mathbb{R}^{2 \times 3 \times 2}
A∈R2×3×2
- B ∈ R 2 × 2 × 2 B \in \mathbb{R}^{2 \times 2 \times 2} B∈R2×2×2
-
A
∈
R
2
×
3
×
2
A \in \mathbb{R}^{2 \times 3 \times 2}
A∈R2×3×2
- 拼接结果形状: 2 × ( 3 + 2 ) × 2 = 2 × 5 × 2 2 \times (3 + 2) \times 2 = 2 \times 5 \times 2 2×(3+2)×2=2×5×2
- 数学表示:
C = [ A 1 B 1 A 2 B 2 ] C = \begin{bmatrix} A_{1} & B_{1} \\ A_{2} & B_{2} \end{bmatrix} C=[A1A2B1B2]
其中 ( A 1 , A 2 ) ( A_1, A_2 ) (A1,A2) 是 ( A ) ( A ) (A) 的块, ( B 1 , B 2 ) ( B_1, B_2 ) (B1,B2) 是 B B B 的块。
数值示例:
-
输入张量:
- A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= 135246 , 791181012 - B = [ [ 13 14 15 16 ] , [ 17 18 19 20 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}, & \begin{bmatrix} 17 & 18 \\ 19 & 20 \end{bmatrix} \end{bmatrix} B=[[13151416],[17191820]]
-
拼接结果(
dim=1
):
C = [ [ 1 2 3 4 5 6 ] [ 13 14 15 16 ] , [ 7 8 9 10 11 12 ] [ 17 18 19 20 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \quad \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \quad \begin{bmatrix} 17 & 18 \\ 19 & 20 \end{bmatrix} \end{bmatrix} C= 135246 [13151416], 791181012 [17191820]
形状: 2 × 5 × 2 2 \times 5 \times 2 2×5×2
3. 沿 dim=2
拼接(扩展第三个维度)
- 拼接条件:除
dim=2
外,其他维度(2×3)必须一致。 - 假设调整后的张量形状:
- A ∈ R 2 × 3 × 2 A \in \mathbb{R}^{2 \times 3 \times 2} A∈R2×3×2
- B ∈ R 2 × 3 × 3 B \in \mathbb{R}^{2 \times 3 \times 3} B∈R2×3×3
- 拼接结果形状: ( 2 × 3 × ( 2 + 3 ) = 2 × 3 × 5 ) ( 2 \times 3 \times (2 + 3) = 2 \times 3 \times 5 ) (2×3×(2+3)=2×3×5)
- 数学表示:
C = [ A 1 B 1 A 2 B 2 ] C = \begin{bmatrix} A_{1} & B_{1} \\ A_{2} & B_{2} \end{bmatrix} C=[A1A2B1B2]
其中每个块在第三个维度上拼接。
数值示例:
-
输入张量:
-
A
=
[
[
1
2
3
4
5
6
]
,
[
7
8
9
10
11
12
]
]
A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix}
A=
135246
,
791181012
- B = [ [ 13 14 15 16 17 18 19 20 21 ] , [ 22 23 24 25 26 27 28 29 30 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 & 15 \\ 16 & 17 & 18 \\ 19 & 20 & 21 \end{bmatrix}, & \begin{bmatrix} 22 & 23 & 24 \\ 25 & 26 & 27 \\ 28 & 29 & 30 \end{bmatrix} \end{bmatrix} B= 131619141720151821 , 222528232629242730
-
A
=
[
[
1
2
3
4
5
6
]
,
[
7
8
9
10
11
12
]
]
A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix}
A=
135246
,
791181012
-
拼接结果(
dim=2
):
C = [ [ 1 2 13 14 15 3 4 16 17 18 5 6 19 20 21 ] , [ 7 8 22 23 24 9 10 25 26 27 11 12 28 29 30 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 & 13 & 14 & 15 \\ 3 & 4 & 16 & 17 & 18 \\ 5 & 6 & 19 & 20 & 21 \end{bmatrix}, & \begin{bmatrix} 7 & 8 & 22 & 23 & 24 \\ 9 & 10 & 25 & 26 & 27 \\ 11 & 12 & 28 & 29 & 30 \end{bmatrix} \end{bmatrix} C= 135246131619141720151821 , 791181012222528232629242730
形状:$2 \times 3 \times 5 )
关键点总结
-
维度扩展方向:
dim=0
:增加第一个维度的大小(如批处理大小)。dim=1
:增加第二个维度的大小(如通道数或行数)。dim=2
:增加第三个维度的大小(如列数或深度)。
-
形状一致性:
- 所有输入张量在非拼接维度的形状必须完全一致。
-
应用场景:
dim=0
:合并不同批次的图像数据。dim=1
:在通道维度拼接特征图(如图像处理中的多模态数据)。dim=2
:扩展特征的维度(如时间序列中的时间步)。
通过上述示例和表格,可以直观理解三维张量在不同维度上的拼接逻辑。