当前位置: 首页 > news >正文

Pytorch中expand()和repeat()函数使用详解和实战示例

在 PyTorch 中,expand()repeat() 都用于张量维度的扩展与复制,但它们的原理和内存使用方式不同,适用于不同场景。


1、 expand():广播扩展(不复制数据)

功能:

expand() 返回一个视图(view),通过 广播机制 将张量在指定维度“扩展”,不复制内存

限制:

只能在原始维度为 1 的轴上扩展,不能创建新的维度,也不能在非 1 的维度扩展。


示例:

import torchx = torch.tensor([[1], [2], [3]])  # shape: (3, 1)# 扩展第二维到4
y = x.expand(3, 4)print("x:\n", x)
print("y:\n", y)

输出:

x:tensor([[1],[2],[3]])
y:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])

注意:

虽然看起来 y 是复制的,但它没有真正复制数据,只是广播视图,不占用额外内存。


2、repeat():真实复制张量内容

功能:

repeat() 沿指定维度进行实际数据复制,得到一个新的张量。与 expand() 不同,它复制数据,占用更多内存,但灵活性更强。


示例:

x = torch.tensor([[1], [2], [3]])  # shape: (3, 1)# 在第0维重复1次,第1维重复4次
y = x.repeat(1, 4)  # shape: (3, 4)print("x:\n", x)
print("y:\n", y)

输出:

x:tensor([[1],[2],[3]])
y:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])

3、 对比总结

特性expand()repeat()
是否复制内存❌ 否,仅创建视图✅ 是,真正复制数据
是否支持广播✅ 只能在维度为 1 的轴广播❌ 直接复制,不依赖维度是否为1
内存开销小(共享内存)大(复制数据)
灵活性限制多,效率高更灵活但效率低
常用于场景batch 中广播参数、掩码构造等构造 tile 模式张量、数据重复

4、 实战对比:

x = torch.tensor([1, 2, 3])  # shape: (3,)# reshape 成 (3,1) 才能 expand
x1 = x.view(3, 1).expand(3, 4)
x2 = x.view(3, 1).repeat(1, 4)print("expand:\n", x1)
print("repeat:\n", x2)

输出:

expand:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])
repeat:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])

虽然结果相同,但 expand() 的效率更高(无数据复制)。

如果希望根据具体场景(如节省内存 vs 需要独立复制)选择操作,建议:

  • 广播 → expand()
  • 真正数据复制 → repeat()

5、实战示例

下面分别通过两个实际应用场景来展示 expand()repeat() 的用法:


1. BERT 中 attention mask 构造

BERT 中的 self-attention 通常需要一个 attention_mask,它的形状是 (batch_size, 1, 1, seq_len)(batch_size, 1, seq_len, seq_len),需要通过 expand() 扩展广播。


示例场景:

你有一个 batch 的文本序列,每个位置为 1 表示有效,0 表示 padding:

import torch# 假设一个 batch 的 attention mask (batch_size=2, seq_len=4)
mask = torch.tensor([[1, 1, 1, 0],[1, 1, 0, 0]])  # shape: (2, 4)

我们需要将其变成 (2, 1, 1, 4) 用于 broadcasting:

# 加维度:batch_size x 1 x 1 x seq_len
mask_expanded = mask.unsqueeze(1).unsqueeze(2)  # shape: (2,1,1,4)# 假设 query length 也为 4,我们要 broadcast 到 (2, 1, 4, 4)
attn_mask = mask_expanded.expand(-1, 1, 4, -1)  # -1 表示保持原维度print(attn_mask.shape)
print(attn_mask)

输出:
torch.Size([2, 1, 4, 4])
tensor([[[[1, 1, 1, 0],[1, 1, 1, 0],[1, 1, 1, 0],[1, 1, 1, 0]]],[[[1, 1, 0, 0],[1, 1, 0, 0],[1, 1, 0, 0],[1, 1, 0, 0]]]])

这就是 BERT 中构造多头注意力掩码时典型使用 expand() 的方式。


2. 图像 tile:使用 repeat() 扩展图像张量

例如我们有一个灰度图像 1x1x28x28(batch_size=1, channel=1),我们希望将这个图像 横向复制 2 次、纵向复制 3 次,形成一个大的拼接图像。


示例代码:
import torch# 创建一个伪图像:1个通道,28x28
img = torch.arange(28*28).reshape(1, 1, 28, 28).float()# 纵向重复3次,横向重复2次
# repeat参数: batch, channel, height_repeat, width_repeat
tiled_img = img.repeat(1, 1, 3, 2)  # 形状: (1, 1, 84, 56)print(tiled_img.shape)

输出:
torch.Size([1, 1, 84, 56])

此操作实际复制了图像数据,每个像素在目标 tensor 中占有真实空间,可以用于拼接生成大图或训练 tile-based 图像模型。


3、 总结对比

场景操作函数原因
BERT attention maskexpand()避免内存复制,适合广播掩码
图像 tile 复制repeat()必须真实复制图像内容

6、 多头注意力中 mask 对多头扩展的写法(补充资料)

多头注意力(Multi-head Attention) 中,为了让不同的头共享相同的 attention_mask,我们通常需要将 mask 的 shape 从 (batch_size, seq_len) 扩展成 (batch_size, num_heads, seq_len, seq_len)

这个扩展操作通常组合使用 unsqueeze()expand()repeat(),根据具体实现选择是否复制内存。


1、场景设定

我们有:

  • batch size = B
  • sequence length = L
  • number of heads = H

输入的原始 attention mask:

# 原始 padding mask:B × L
mask = torch.tensor([[1, 1, 1, 0],[1, 1, 0, 0]
])  # shape: (2, 4)

2、 多头注意力中扩展 mask 的方法

方法 1:使用 unsqueeze + expand不复制数据,更高效)
B, L, H = 2, 4, 8  # batch, seq_len, num_heads# 1. 原始 mask shape: (B, L)
# 2. 先加两个维度:B × 1 × 1 × L
mask = mask.unsqueeze(1).unsqueeze(2)  # shape: (2, 1, 1, 4)# 3. 扩展到多头:B × H × L × L
mask = mask.expand(B, H, L, L)print(mask.shape)

输出:

torch.Size([2, 8, 4, 4])

这个扩展方式非常适合 BERT 等 Transformer 模型中的 attention_mask 构造,不会额外占用内存。


方法 2:使用 repeat复制数据

如果你希望生成独立的副本(比如 mask 后面会被修改),可以用 repeat

# 先加两个维度:B × 1 × 1 × L
mask = mask.unsqueeze(1).unsqueeze(2)  # (2,1,1,4)# 重复到多头:B × H × L × L
mask = mask.repeat(1, H, L, 1)  # batch维不变,头重复H次,query和key维重复

3、用在 Attention 权重前:

# 假设 attention logits shape: (B, H, L, L)
attn_logits = torch.randn(B, H, L, L)# mask == 0 的地方,我们不希望注意力流动,可设为 -inf(或 -1e9)
attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))

然后将 attn_logits 送入 softmax。


4、 总结:选择 expand or repeat

场景推荐操作原因
构造 attention mask(只读)expand()高效、无数据复制
构造后需要修改repeat()每个位置是独立内存
http://www.dtcms.com/a/265587.html

相关文章:

  • github在线图床
  • 一篇文章掌握Docker
  • Redis 持久化详解、使用及注意事项
  • 关于使用cursor tunnel链接vscode(避免1006 issue的做法)
  • ASP 安装使用教程
  • ubuntu rules 使用规则
  • 什么是VR全景展示?VR展示需要哪些科技?
  • 【React Native原生项目不能运行npx react-native run-android项目】
  • 学习设计模式《十六》——策略模式
  • 安装 Docker Compose!!!
  • 蒙特卡洛方法:随机抽样的艺术与科学
  • SSL Pinning破解实战:企业级移动应用安全测试方案
  • java集合详解
  • 论文阅读笔记——Autoregressive Image Generation without Vector Quantization
  • 当材料研发遇上「数字集装箱」:Docker如何让科研效率「开挂」?
  • 【unity游戏开发——优化篇】使用Occlusion Culling遮挡剔除,只渲染相机视野内的游戏物体提升游戏性能
  • AES密码算法的C语言实现(带测试)
  • 经典灰狼算法+编码器+双向长短期记忆神经网络,GWO-Transformer-BiLSTM多变量回归预测,作者:机器学习之心!
  • 【TTS】2024-2025年主流开源TTS模型的综合对比分析
  • 仿星露谷物语开发总结VIP(Unity高级编程知识)
  • RabbitMQ 通过HTTP API删除队列命令
  • 【RK3568+PG2L50H开发板实验例程】Linux部分/FPGA FSPI 通信案例
  • 【机器学习深度学习】什么是下游任务模型?
  • laravel基础:php artisan make:model Flight --all 详解
  • 【PaddleOCR】OCR文本检测与文本识别数据集整理,持续更新......
  • 【QT】QWidget控件详解 || 常用的API
  • 蓝桥杯C++组算法知识点整理 · 考前突击(中)【小白适用】
  • Java调用百度地图天气查询服务获取当前和未来天气-以贵州省榕江县为例
  • 【字节跳动】数据挖掘面试题0006:SVM(支持向量机)详细原理
  • JVM类加载过程