当前位置: 首页 > news >正文

PyTorch --torch.cat张量拼接原理

在 PyTorch 的 torch.cat 函数中,out 参数用于指定输出张量的存储位置。是否使用 out 参数直接影响结果的存储方式和张量的内存行为。以下是详细解释:

  1. 不使用 out 参数(默认行为)
    含义:不提供 out 参数时,torch.cat 会创建一个新的张量来存储拼接后的结果,并返回这个新张量。
    特点:
    内存分配:PyTorch 会为结果分配新的内存空间。
    原张量不变:输入的原始张量(如 tensors 中的张量)不会被修改。
    返回新张量:返回的张量是独立的,与输入张量没有内存共享
  2. 使用 out 参数
    含义:通过 out 参数提供一个已存在的张量,torch.cat 将直接将结果写入该张量中,无需创建新张量。
    特点:
    内存复用:避免分配新内存,直接利用已有张量的内存空间。
    原张量被修改:out 指定的张量会被覆盖,其内容会被替换为拼接结果。
    形状匹配:out 张量的形状必须与拼接后的结果完全一致,否则会报错。

以下是关于 torch.cat 在不同 dim 下的拼接过程的公式化描述及可视化示例,用分块矩阵的形式呈现:


1. 数学公式化描述

1.1 沿 dim=0(行方向)拼接

假设:

  • 张量 A 的形状为 m × n m\times n m×n
  • 张量 B 的形状为 p × n p\times n p×n
  • 所有张量在非拼接维度(列数 n n n)必须一致。

拼接后的张量 C 形状为 ( m + p ) × n (m + p) \times n (m+p)×n,公式表示为: C = [ A B ] C = \begin{bmatrix}A \\B\end{bmatrix} C=[AB]
即:
C i , j = { A i , j , 若  1 ≤ i ≤ m , B i − m , j , 若  m + 1 ≤ i ≤ m + p . C_{i,j} = \begin{cases} A_{i,j}, & \text{若 } 1 \leq i \leq m, \\ B_{i-m,j}, & \text{若 } m+1 \leq i \leq m+p. \end{cases} Ci,j={Ai,j,Bim,j, 1im, m+1im+p.


1.2 沿 dim=1(水平/列方向)拼接

假设条件

  • 张量 A A A 的形状为 m × n m \times n m×n
  • 张量 B B B 的形状为 m × p m \times p m×p
  • 两个张量在非拼接维度(行维度 m m m)上必须保持一致

拼接操作
水平拼接后的张量 C C C 形状为 m × ( n + p ) m \times (n + p) m×(n+p),其数学表示为:

C = [ A B ] C = \begin{bmatrix} A & B \end{bmatrix} C=[AB]

元素级定义
C i , j = { A i , j 当  1 ≤ j ≤ n B i , j − n 当  n + 1 ≤ j ≤ n + p C_{i,j} = \begin{cases} A_{i,j} & \text{当 } 1 \leq j \leq n \\ B_{i,j-n} & \text{当 } n+1 \leq j \leq n+p \end{cases} Ci,j={Ai,jBi,jn 1jn n+1jn+p

维度说明

  • 行维度: m m m(保持不变)
  • 列维度: n + p n + p n+p A A A B B B 列数的总和)

2. 具体示例(使用数值矩阵)

张量拼接示例

示例 1:沿第 0 维度拼接 (dim=0)

输入张量

  • A = [ 1 2 3 4 ] ( 形状  2 × 2 ) A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \quad (\text{形状 } 2 \times 2) A=[1324](形状 2×2)
  • B = [ 5 6 7 8 ] ( 形状  2 × 2 ) B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形状 } 2 \times 2) B=[5768](形状 2×2)

拼接操作
C = concat ⁡ ( A , B , dim = 0 ) C = \operatorname{concat}(A, B, \text{dim}=0) C=concat(A,B,dim=0)

输出结果
C = [ 1 2 3 4 5 6 7 8 ] ( 形状  4 × 2 ) C = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \hline 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形状 } 4 \times 2) C= 13572468 (形状 4×2)


示例 2:沿第 1 维度拼接 (dim=1)

输入张量

  • A = [ 1 2 3 4 ] ( 形状  2 × 2 ) A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \quad (\text{形状 } 2 \times 2) A=[1324](形状 2×2)
  • B = [ 5 6 7 8 ] ( 形状  2 × 2 ) B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形状 } 2 \times 2) B=[5768](形状 2×2)

拼接操作
C = concat ⁡ ( A , B , dim = 1 ) C = \operatorname{concat}(A, B, \text{dim}=1) C=concat(A,B,dim=1)

输出结果
C = [ 1 2 5 6 3 4 7 8 ] ( 形状  2 × 4 ) C = \begin{bmatrix} 1 & 2 & 5 & 6 \\ 3 & 4 & 7 & 8 \end{bmatrix} \quad (\text{形状 } 2 \times 4) C=[13245768](形状 2×4)

关键说明

  1. dim=0 表示垂直拼接(沿行方向堆叠)
  2. dim=1 表示水平拼接(沿列方向连接)
  3. 拼接维度的大小可以不同,但其他维度必须完全相同(例如 dim=1 拼接时,两个张量的行数必须相等)

以下是用表格形式展示 dim=0dim=1 的拼接结果:

沿 dim=0 拼接
初始张量 ( A )初始张量 ( B )拼接结果 ( C )(dim=0)
[[1, 2],[[5, 6],[[1, 2],
[3, 4]] [7, 8]] [3, 4],
形状:2×2形状:2×2 [5, 6],
[7, 8]]
形状:4×2
沿 dim=1 拼接
初始张量 ( A )初始张量 ( B )拼接结果 ( C )(dim=1)
[[1, 2],[[5, 6],[[1, 2, 5, 6],
[3, 4]] [7, 8]] [3, 4, 7, 8]]
形状:2×2形状:2×2形状:2×4

3. 关键点总结

  1. 维度一致性

    • dim=0:所有张量的列数(n)必须相同。
    • dim=1:所有张量的行数(m)必须相同。
  2. 拼接方向

    • dim=0:垂直方向拼接(行数相加)。
    • dim=1:水平方向拼接(列数相加)。
  3. 数学符号表示

    • dim=0 C = [ A B ] C = \begin{bmatrix} A \\ B \end{bmatrix} C=[AB]
    • dim=1 C = [ A B ] C = \begin{bmatrix} A & B \end{bmatrix} C=[AB]

5. 扩展示例(多维张量)

假设张量为三维(如图像的批处理):
A ∈ R B × C × H × W A \in \mathbb{R}^{B \times C \times H \times W} ARB×C×H×W
B ∈ R B ′ × C × H × W B \in \mathbb{R}^{B' \times C \times H \times W} BRB×C×H×W

  • 拼接 dim=0(批处理方向)
    C ∈ R ( B + B ′ ) × C × H × W C \in \mathbb{R}^{(B+B') \times C \times H \times W} CR(B+B)×C×H×W

以下是三维张量拼接的示例,使用分块矩阵的形式展示沿不同维度(dim=0, dim=1, dim=2)的拼接过程:


三维张量拼接示例

假设两个三维张量 ( A ) 和 ( B ),形状分别为:
- A ∈ R 2 × 3 × 2 (形状: 2 × 3 × 2 ) A \in \mathbb{R}^{2 \times 3 \times 2}(形状:2×3×2) AR2×3×2(形状:2×3×2
- B ∈ R 1 × 3 × 2 (形状: 1 × 3 × 2 ) B \in \mathbb{R}^{1 \times 3 \times 2} (形状:1×3×2) BR1×3×2(形状:1×3×2

1. 沿 dim=0 拼接(扩展第一个维度)
  • 拼接条件:除 dim=0 外,其他维度(3×2)必须一致。
  • 拼接结果形状 ( 2 + 1 ) × 3 × 2 = 3 × 3 × 2 (2 + 1) \times 3 \times 2 = 3 \times 3 \times 2 (2+1)×3×2=3×3×2
  • 数学表示
    C = [ A 1 A 2 B 1 ] C = \begin{bmatrix} A_{1} \\ A_{2} \\ B_{1} \end{bmatrix} C= A1A2B1
    其中 A 1 , A 2 A_1, A_2 A1,A2 A A A的两个“块”, B 1 B_1 B1 B B B的唯一块。
数值示例
  • 输入张量

    • A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= 135246 , 791181012
    • B = [ [ 13 14 15 16 17 18 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 17 & 18 \end{bmatrix} \end{bmatrix} B= 131517141618
  • 拼接结果(dim=0
    C = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] , [ 13 14 15 16 17 18 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix}, & \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 17 & 18 \end{bmatrix} \end{bmatrix} C= 135246 , 791181012 , 131517141618
    形状 3 × 3 × 2 3 \times 3 \times 2 3×3×2


2. 沿 dim=1 拼接(扩展第二个维度)
  • 拼接条件:除 dim=1 外,其他维度(2×2)必须一致。
  • 假设调整后的张量形状
    • A ∈ R 2 × 3 × 2 A \in \mathbb{R}^{2 \times 3 \times 2} AR2×3×2
      - B ∈ R 2 × 2 × 2 B \in \mathbb{R}^{2 \times 2 \times 2} BR2×2×2
  • 拼接结果形状 2 × ( 3 + 2 ) × 2 = 2 × 5 × 2 2 \times (3 + 2) \times 2 = 2 \times 5 \times 2 2×(3+2)×2=2×5×2
  • 数学表示
    C = [ A 1 B 1 A 2 B 2 ] C = \begin{bmatrix} A_{1} & B_{1} \\ A_{2} & B_{2} \end{bmatrix} C=[A1A2B1B2]
    其中 ( A 1 , A 2 ) ( A_1, A_2 ) (A1,A2) ( A ) ( A ) (A) 的块, ( B 1 , B 2 ) ( B_1, B_2 ) (B1,B2) B B B 的块。
数值示例
  • 输入张量
    - A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= 135246 , 791181012

    • B = [ [ 13 14 15 16 ] , [ 17 18 19 20 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}, & \begin{bmatrix} 17 & 18 \\ 19 & 20 \end{bmatrix} \end{bmatrix} B=[[13151416],[17191820]]
  • 拼接结果(dim=1
    C = [ [ 1 2 3 4 5 6 ] [ 13 14 15 16 ] , [ 7 8 9 10 11 12 ] [ 17 18 19 20 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \quad \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \quad \begin{bmatrix} 17 & 18 \\ 19 & 20 \end{bmatrix} \end{bmatrix} C= 135246 [13151416], 791181012 [17191820]
    形状 2 × 5 × 2 2 \times 5 \times 2 2×5×2


3. 沿 dim=2 拼接(扩展第三个维度)
  • 拼接条件:除 dim=2 外,其他维度(2×3)必须一致。
  • 假设调整后的张量形状
    • A ∈ R 2 × 3 × 2 A \in \mathbb{R}^{2 \times 3 \times 2} AR2×3×2
    • B ∈ R 2 × 3 × 3 B \in \mathbb{R}^{2 \times 3 \times 3} BR2×3×3
  • 拼接结果形状 ( 2 × 3 × ( 2 + 3 ) = 2 × 3 × 5 ) ( 2 \times 3 \times (2 + 3) = 2 \times 3 \times 5 ) (2×3×(2+3)=2×3×5)
  • 数学表示
    C = [ A 1 B 1 A 2 B 2 ] C = \begin{bmatrix} A_{1} & B_{1} \\ A_{2} & B_{2} \end{bmatrix} C=[A1A2B1B2]
    其中每个块在第三个维度上拼接。
数值示例
  • 输入张量

    • A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= 135246 , 791181012
      - B = [ [ 13 14 15 16 17 18 19 20 21 ] , [ 22 23 24 25 26 27 28 29 30 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 & 15 \\ 16 & 17 & 18 \\ 19 & 20 & 21 \end{bmatrix}, & \begin{bmatrix} 22 & 23 & 24 \\ 25 & 26 & 27 \\ 28 & 29 & 30 \end{bmatrix} \end{bmatrix} B= 131619141720151821 , 222528232629242730
  • 拼接结果(dim=2
    C = [ [ 1 2 13 14 15 3 4 16 17 18 5 6 19 20 21 ] , [ 7 8 22 23 24 9 10 25 26 27 11 12 28 29 30 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 & 13 & 14 & 15 \\ 3 & 4 & 16 & 17 & 18 \\ 5 & 6 & 19 & 20 & 21 \end{bmatrix}, & \begin{bmatrix} 7 & 8 & 22 & 23 & 24 \\ 9 & 10 & 25 & 26 & 27 \\ 11 & 12 & 28 & 29 & 30 \end{bmatrix} \end{bmatrix} C= 135246131619141720151821 , 791181012222528232629242730
    形状:$2 \times 3 \times 5 )


关键点总结

  1. 维度扩展方向

    • dim=0:增加第一个维度的大小(如批处理大小)。
    • dim=1:增加第二个维度的大小(如通道数或行数)。
    • dim=2:增加第三个维度的大小(如列数或深度)。
  2. 形状一致性

    • 所有输入张量在非拼接维度的形状必须完全一致。
  3. 应用场景

    • dim=0:合并不同批次的图像数据。
    • dim=1:在通道维度拼接特征图(如图像处理中的多模态数据)。
    • dim=2:扩展特征的维度(如时间序列中的时间步)。

通过上述示例和表格,可以直观理解三维张量在不同维度上的拼接逻辑。

相关文章:

  • 前端er在Cursor使用MCP实现精选照片的快速上手教程
  • AISTATS 2025 | ChronosX:利用外生变量调整预训练时间序列模型
  • Fnos 飞牛Nas安装桌面环境 gnome和KDE桌面- All in One 笔记~1
  • Dubbo(25)如何配置Dubbo的协议和端口?
  • 【减小图片打包体积】image-webpack-loader
  • MySQL--数据备份
  • 实时数据流处理利器:Apache Storm 在大数据中的应用
  • .Net中对称加密的实现
  • cJSON类型及type值详解
  • ECharts系列: Vue 中使用 ECharts 折线图时,怎么配置来实现默认隐藏某些图例,并在用户点击图例时显示或隐藏对应的数据系列
  • MySQL的事务
  • Springboot3.x集成Spring Batch 5.2.1
  • 面试经典150题·LeetCode26·删除有序数组中的重复项·Java
  • 18.redis基本操作
  • 内积相似系数——内积度量相似系数
  • html 列表循环滚动,动态初始化字段数据
  • Android 隐藏手势模式下输入法的BackButton和ImeSwitchButton
  • Vue项目中Vuex在util引入,断点存在default
  • EI复现:蜣螂优化算法变体合集上新,改进正弦算法引导的蜣螂优化算法
  • ts中 构造器
  • 专业做财经直播网站/做网站要多少钱
  • 武汉大学人民医院招聘/北仑seo排名优化技术
  • 商业政府网站cms/百度推广代理公司
  • 自己做电商网站./国家职业技能培训学校
  • 电子商务网站开发形式有/重庆seo排名
  • 做网站用主机/搜索引擎营销sem包括