当前位置: 首页 > news >正文

pytorch小记(十七):PyTorch 中的 `expand` 与 `repeat`:详解广播机制与复制行为(附详细示例)

pytorch小记(十七):PyTorch 中的 `expand` 与 `repeat`:详解广播机制与复制行为(附详细示例)

    • 🚀 PyTorch 中的 `expand` 与 `repeat`:详解广播机制与复制行为(附详细示例)
      • 🔍 一、基础定义
        • 1. `tensor.expand(*sizes)`
        • 2. `tensor.repeat(*sizes)`
      • 📌 二、维度行为详解
        • 使用 `expand`
        • 使用 `repeat`
      • ⚠️ 三、重点报错案例解释
        • 📌 示例 1:`expand(1, 4)` 报错
        • ✅ 示例 2:`expand(2, 4)` 正确
      • 🔁 四、repeat 的多种使用场景举例
      • 🔍 五、输入维度对 `expand` 和 `repeat` 的影响总结
      • 🎯 六、常见错误总结
      • ✅ 七、维度补齐技巧
      • 🎓 八、结语:如何选择?
    • 问题
      • 1. PyTorch 自动**广播一维 tensor**
      • 2. 和二维 `[1, 2, 3]` 效果一样?
      • 🔎 为什么以前会报错?
    • 📌 总结规律(适用于新版本 PyTorch)


🚀 PyTorch 中的 expandrepeat:详解广播机制与复制行为(附详细示例)

在使用 PyTorch 构建神经网络时,经常会遇到不同维度张量需要对齐的问题,expand()repeat() 就是两种非常常用的方式来处理张量的形状变化。本博客将详细解释两者的区别、作用、使用规则以及典型的报错原因,配合实际例子,帮助你深入理解广播机制。


🔍 一、基础定义

1. tensor.expand(*sizes)
  • 功能:沿指定维度进行“虚拟复制”,不占用额外内存
  • 要求:只能扩展 原始维度中为1的维度,否则会报错。
2. tensor.repeat(*sizes)
  • 功能真正复制数据,生成新的内存区域。
  • 不限制是否为1的维度,任意维度都能复制。

📌 二、维度行为详解

以一个张量为例:

a = torch.tensor([[1], [2]])  # shape: (2, 1)
使用 expand
print(a.expand(2, 3))

结果:

tensor([[1, 1, 1],
        [2, 2, 2]])
  • 第1维为 1,可以扩展成3列。
  • 数据并没有真实复制,只是通过 广播机制 显示为多列。
使用 repeat
print(a.repeat(1, 3))

结果:

tensor([[1, 1, 1],
        [2, 2, 2]])
  • 每一行的元素真实地复制了3份,占用了新内存。

⚠️ 三、重点报错案例解释

📌 示例 1:expand(1, 4) 报错
c = torch.tensor([[7], [8]])  # shape: (2, 1)
print(c.expand(1, 4))

错误原因

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.

解释:

  • 原 tensor 的第0维是2,而你想扩展为1。
  • 非1的维度不能进行expand扩展,会触发报错。

✅ 示例 2:expand(2, 4) 正确
c = torch.tensor([[7], [8]])  # shape: (2, 1)
print(c.expand(2, 4))

输出:

tensor([[7, 7, 7, 7],
        [8, 8, 8, 8]])
  • 第0维是2,不变 ✅
  • 第1维是1,被扩展为4 ✅

🔁 四、repeat 的多种使用场景举例

a = torch.tensor([[1, 2, 3]])  # shape: (1, 3)
print(a.repeat(2, 3))

输出:

tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])

解释:

  • (2, 3) 的含义是:行重复2次,列重复3次。
  • 数据真实复制!

🔍 五、输入维度对 expandrepeat 的影响总结

操作输入维度形状输入参数说明
expand必须是显式维度尺寸必须与原tensor维度数一致,且非1的维度不能变
repeat任意形状每个维度对应复制几次
自动广播可扩展1维为任意数目expand底层用到
内存行为不复制数据expand 是 zero-copy
内存行为真正复制repeat 用得多就要小心内存

🎯 六、常见错误总结

错误场景示例错误原因
expand 维度不对tensor(2, 1).expand(1, 4)非1维度不能扩展
expand 维数不匹配tensor(2, 1).expand(4)参数数目与维度数不一致
repeat 维度数对不上tensor(2, 1).repeat(3)参数不够,需要补齐

✅ 七、维度补齐技巧

有时原始张量的维度太少,需要先 .unsqueeze() 添加维度:

x = torch.tensor([1, 2, 3])   # shape: (3,)
x = x.unsqueeze(0)            # shape: (1, 3)
x = x.expand(2, 3)

🎓 八、结语:如何选择?

  • 如果你只是想“假装复制”以减少内存开销 ➜ expand()
  • 如果你真的需要重复数据去喂模型 ➜ repeat()
  • 如果你想安全无脑复制 ➜ repeat() 更通用但代价大
  • 如果你要配合 broadcasting ➜ expand() 是你的最优选择

问题

a = torch.tensor([[1, 2, 3]])  # shape: (1, 3)
print(a.shape)
print(a.repeat(6, 4))

a = torch.tensor([1, 2, 3])  # shape: (1, 3)
print(a.shape)
print(a.repeat(6, 4))

为什么维度不同但是输出是一样的?

1. PyTorch 自动广播一维 tensor

在新版 PyTorch 中(大约 1.8 起),当你对 一维张量 调用 .repeat(m, n),PyTorch 会自动地把它当作 shape 为 (1, 3),然后再执行 repeat。这相当于隐式地:

a = torch.tensor([1, 2, 3])    # shape: (3,)
a = a.unsqueeze(0)             # shape: (1, 3)
print(a.repeat(6, 4))          # 🔁 repeat(6, 4) 等价于 (6 rows, 12 columns)

2. 和二维 [1, 2, 3] 效果一样?

是的。你对比的两个 tensor:

a1 = torch.tensor([[1, 2, 3]])  # shape: (1, 3)
a2 = torch.tensor([1, 2, 3])    # shape: (3,)
print(a1.repeat(6, 4))
print(a2.repeat(6, 4))  # 现在两者结果完全一致!

输出都是 shape: (6, 12),值为重复的 [1, 2, 3]

tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        ...
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])

🔎 为什么以前会报错?

在早期版本的 PyTorch 中(<1.8),repeat(6, 4) 要求参数个数和维度完全一致。所以对 a = torch.tensor([1,2,3])(一维)来说,你只能:

a.repeat(6)  # 正确,对一维张量
a.repeat(6, 4)  # 错误(旧版本)

📌 总结规律(适用于新版本 PyTorch)

原始 tensorrepeat 维度自动行为结果
[1,2,3] (1维)repeat(6,4)自动 unsqueeze → (1,3)
[[1,2,3]](2维)repeat(6,4)直接 repeat
[1,2,3](1维)repeat(6)沿第0维重复
[[1,2,3]](2维)repeat(6)报错,维度不匹配

相关文章:

  • [定位器]晶艺LA1823,4.5V~100V, 3.5A,替换MP9487,MP9486A,启烨科技
  • 2025.4.9总结
  • c++比较器——priority_queue用 ; unordered_map 自定义哈希函数
  • 基于Redis实现短信防轰炸的Java解决方案
  • 唯一分解定理
  • 系统与网络安全------网络通信原理(4)
  • 每日算法:洛谷U535992 J-C 小梦的宝石收集(双指针、二分)
  • 金融级隐私安全之DeepSeek R1 模型去中心化存储和推理实现方案
  • python爬虫:喜马拉雅案例(破解sign值)
  • 以库存系统为核心的ERP底层架构设计
  • git单独跟踪远程分支及处理合并异常情况
  • 蓝桥杯嵌入式第十五届
  • C++【string类】(一)
  • C语言关键字
  • 认识 Linux 内存构成:Linux 内存调优之虚拟内存与物理内存
  • DAY 39 leetcode 18--哈希表.四数之和
  • 支持企业知识库和联网搜索,360AI企业知识库驱动业务深度融合
  • 小刚说C语言刷题——第19讲 循环之continue和break
  • 计算机网络(1)
  • 【SpringCloud】从入门到精通(上)
  • 网站项目建设主要内容/手机网址大全123客户端下载
  • 网站模式下做淘宝客/廊坊百度推广电话
  • nba新闻那个网站做的好/网络运营推广
  • wordpress站标签也打不开/参考消息今天新闻
  • 网站建设app销售好做吗/拓客平台有哪些
  • 武汉网页平面设计/网站推广与优化平台