当前位置: 首页 > news >正文

torch.cat和torch.stack的区别

torch.cat 和 torch.stack 是 PyTorch 中用于组合张量的两个常用函数,它们的核心区别在于输入张量的维度和输出张量的维度变化。以下是详细对比:

1. torch.cat (Concatenate)

  • 作用:沿现有维度拼接多个张量,不创建新维度
  • 输入要求:所有张量的形状必须除拼接维度外完全相同

  • 语法

    torch.cat(tensors, dim=0)  # dim 指定拼接的维度
  • 示例

    a = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
    b = torch.tensor([[5, 6]])           # shape (1, 2)
    
    # 沿 dim=0 拼接(行方向)
    c = torch.cat([a, b], dim=0)
    print(c)
    # tensor([[1, 2],
    #         [3, 4],
    #         [5, 6]])  # shape (3, 2)
  • 特点
    • 拼接后的张量在指定维度上的大小是输入张量该维度大小的总和。

    • 其他维度必须完全一致。

2. torch.stack

  • 作用:沿新维度堆叠多个张量,创建新维度

  • 输入要求:所有张量的形状必须完全相同

  • 语法

    torch.stack(tensors, dim=0)  # dim 指定新维度的位置
  • 示例

    a = torch.tensor([1, 2])  # shape (2,)
    b = torch.tensor([3, 4])  # shape (2,)
    
    # 沿新维度 dim=0 堆叠
    c = torch.stack([a, b], dim=0)
    print(c)
    # tensor([[1, 2],
    #         [3, 4]])  # shape (2, 2)
    
    # 沿新维度 dim=1 堆叠
    d = torch.stack([a, b], dim=1)
    print(d)
    # tensor([[1, 3],
    #         [2, 4]])  # shape (2, 2)
  • 特点

    • 输出张量比输入张量多一个维度

    • 适用于将多个相同形状的张量合并为批次(如 batch_size 维度)。

3. 关键区别总结

4. 直观对比示例

假设有两个张量:

x = torch.tensor([1, 2])  # shape (2,)
y = torch.tensor([3, 4])  # shape (2,)

torch.cat 结果

torch.cat([x, y], dim=0)  # tensor([1, 2, 3, 4]), shape (4,)

torch.stack 结果

torch.stack([x, y], dim=0)  # tensor([[1, 2], [3, 4]]), shape (2, 2)

5. 如何选择?

  • 用 torch.cat 当需要扩展现有维度(如拼接多个特征图)。

  • 用 torch.stack 当需要创建新维度(如构建批次数据或堆叠不同模型的输出)

通过理解两者的维度变化逻辑,可以避免常见的形状错误(如 size mismatch)。 

相关文章:

  • 应急响应靶机-Linux(1)
  • 数据结构*包装类泛型
  • C语言进阶之指针
  • CMD命令行笔记
  • 数据库实验:分组查询与聚集函数的使用
  • Vue3状态管理深度实战:Pinia架构设计与企业级应用
  • C#核心学习(十六)面向对象--关联知识点(2)string和Stringbuilder
  • 案例驱动的 IT 团队管理:创新与突破之路: 第四章 危机应对:从风险预见到创新破局-4.1.3重构过程中的团队士气管理
  • 202524 | 分布式事务
  • 《基于 RNN 的股票预测模型代码优化:从重塑到直接可视化》
  • 设计模式——抽象工厂模式总结
  • AcWing 6093. 不互质子序列
  • ubuntu 安装samba
  • 奇葩问题:PGPOOL自动容灾切换,主备不生效原因
  • 部署YUM仓库
  • 卡码网55:右旋字符串
  • 【android bluetooth 框架分析 02】【Module详解 3】【HciHal 模块介绍】
  • 【tRPC-go】message、context相关源码设计思路
  • vue 入门:生命周期
  • 【第16届蓝桥杯C++C组】--- 数位倍数
  • 新能源车盈利拐点:8家上市车企去年合计净利854亿元,多家扭亏
  • 美国第一季度经济环比萎缩0.3%
  • 浪尖计划再出发:万亿之城2030课题组赴九城调研万亿产业
  • “80后”商洛市委副书记、市政府党组副书记赵孝任商洛市副市长
  • 应急管理部派出工作组赴山西太原小区爆炸现场指导救援处置
  • 京津冀“飘絮之困”如何破解?专家坦言仍面临关键技术瓶颈