Pytorch中expand()和repeat()函数使用详解和实战示例
在 PyTorch 中,expand()
和 repeat()
都用于张量维度的扩展与复制,但它们的原理和内存使用方式不同,适用于不同场景。
1、 expand()
:广播扩展(不复制数据)
功能:
expand()
返回一个视图(view),通过 广播机制 将张量在指定维度“扩展”,不复制内存。
限制:
只能在原始维度为 1
的轴上扩展,不能创建新的维度,也不能在非 1 的维度扩展。
示例:
import torchx = torch.tensor([[1], [2], [3]]) # shape: (3, 1)# 扩展第二维到4
y = x.expand(3, 4)print("x:\n", x)
print("y:\n", y)
输出:
x:tensor([[1],[2],[3]])
y:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])
注意:
虽然看起来 y
是复制的,但它没有真正复制数据,只是广播视图,不占用额外内存。
2、repeat()
:真实复制张量内容
功能:
repeat()
沿指定维度进行实际数据复制,得到一个新的张量。与 expand()
不同,它复制数据,占用更多内存,但灵活性更强。
示例:
x = torch.tensor([[1], [2], [3]]) # shape: (3, 1)# 在第0维重复1次,第1维重复4次
y = x.repeat(1, 4) # shape: (3, 4)print("x:\n", x)
print("y:\n", y)
输出:
x:tensor([[1],[2],[3]])
y:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])
3、 对比总结
特性 | expand() | repeat() |
---|---|---|
是否复制内存 | ❌ 否,仅创建视图 | ✅ 是,真正复制数据 |
是否支持广播 | ✅ 只能在维度为 1 的轴广播 | ❌ 直接复制,不依赖维度是否为1 |
内存开销 | 小(共享内存) | 大(复制数据) |
灵活性 | 限制多,效率高 | 更灵活但效率低 |
常用于场景 | batch 中广播参数、掩码构造等 | 构造 tile 模式张量、数据重复 |
4、 实战对比:
x = torch.tensor([1, 2, 3]) # shape: (3,)# reshape 成 (3,1) 才能 expand
x1 = x.view(3, 1).expand(3, 4)
x2 = x.view(3, 1).repeat(1, 4)print("expand:\n", x1)
print("repeat:\n", x2)
输出:
expand:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])
repeat:tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])
虽然结果相同,但 expand()
的效率更高(无数据复制)。
如果希望根据具体场景(如节省内存 vs 需要独立复制)选择操作,建议:
- 广播 →
expand()
- 真正数据复制 →
repeat()
5、实战示例
下面分别通过两个实际应用场景来展示 expand()
和 repeat()
的用法:
1. BERT 中 attention mask
构造
BERT 中的 self-attention 通常需要一个 attention_mask
,它的形状是 (batch_size, 1, 1, seq_len)
或 (batch_size, 1, seq_len, seq_len)
,需要通过 expand()
扩展广播。
示例场景:
你有一个 batch 的文本序列,每个位置为 1 表示有效,0 表示 padding:
import torch# 假设一个 batch 的 attention mask (batch_size=2, seq_len=4)
mask = torch.tensor([[1, 1, 1, 0],[1, 1, 0, 0]]) # shape: (2, 4)
我们需要将其变成 (2, 1, 1, 4)
用于 broadcasting:
# 加维度:batch_size x 1 x 1 x seq_len
mask_expanded = mask.unsqueeze(1).unsqueeze(2) # shape: (2,1,1,4)# 假设 query length 也为 4,我们要 broadcast 到 (2, 1, 4, 4)
attn_mask = mask_expanded.expand(-1, 1, 4, -1) # -1 表示保持原维度print(attn_mask.shape)
print(attn_mask)
输出:
torch.Size([2, 1, 4, 4])
tensor([[[[1, 1, 1, 0],[1, 1, 1, 0],[1, 1, 1, 0],[1, 1, 1, 0]]],[[[1, 1, 0, 0],[1, 1, 0, 0],[1, 1, 0, 0],[1, 1, 0, 0]]]])
这就是 BERT 中构造多头注意力掩码时典型使用 expand()
的方式。
2. 图像 tile:使用 repeat()
扩展图像张量
例如我们有一个灰度图像 1x1x28x28
(batch_size=1, channel=1),我们希望将这个图像 横向复制 2 次、纵向复制 3 次,形成一个大的拼接图像。
示例代码:
import torch# 创建一个伪图像:1个通道,28x28
img = torch.arange(28*28).reshape(1, 1, 28, 28).float()# 纵向重复3次,横向重复2次
# repeat参数: batch, channel, height_repeat, width_repeat
tiled_img = img.repeat(1, 1, 3, 2) # 形状: (1, 1, 84, 56)print(tiled_img.shape)
输出:
torch.Size([1, 1, 84, 56])
此操作实际复制了图像数据,每个像素在目标 tensor 中占有真实空间,可以用于拼接生成大图或训练 tile-based 图像模型。
3、 总结对比
场景 | 操作函数 | 原因 |
---|---|---|
BERT attention mask | expand() | 避免内存复制,适合广播掩码 |
图像 tile 复制 | repeat() | 必须真实复制图像内容 |
6、 多头注意力中 mask 对多头扩展的写法(补充资料)
在 多头注意力(Multi-head Attention) 中,为了让不同的头共享相同的 attention_mask
,我们通常需要将 mask 的 shape 从 (batch_size, seq_len) 扩展成 (batch_size, num_heads, seq_len, seq_len)。
这个扩展操作通常组合使用 unsqueeze()
、expand()
或 repeat()
,根据具体实现选择是否复制内存。
1、场景设定
我们有:
- batch size =
B
- sequence length =
L
- number of heads =
H
输入的原始 attention mask:
# 原始 padding mask:B × L
mask = torch.tensor([[1, 1, 1, 0],[1, 1, 0, 0]
]) # shape: (2, 4)
2、 多头注意力中扩展 mask 的方法
方法 1:使用 unsqueeze + expand
(不复制数据,更高效)
B, L, H = 2, 4, 8 # batch, seq_len, num_heads# 1. 原始 mask shape: (B, L)
# 2. 先加两个维度:B × 1 × 1 × L
mask = mask.unsqueeze(1).unsqueeze(2) # shape: (2, 1, 1, 4)# 3. 扩展到多头:B × H × L × L
mask = mask.expand(B, H, L, L)print(mask.shape)
输出:
torch.Size([2, 8, 4, 4])
这个扩展方式非常适合 BERT 等 Transformer 模型中的 attention_mask
构造,不会额外占用内存。
方法 2:使用 repeat
(复制数据)
如果你希望生成独立的副本(比如 mask 后面会被修改),可以用 repeat
:
# 先加两个维度:B × 1 × 1 × L
mask = mask.unsqueeze(1).unsqueeze(2) # (2,1,1,4)# 重复到多头:B × H × L × L
mask = mask.repeat(1, H, L, 1) # batch维不变,头重复H次,query和key维重复
3、用在 Attention 权重前:
# 假设 attention logits shape: (B, H, L, L)
attn_logits = torch.randn(B, H, L, L)# mask == 0 的地方,我们不希望注意力流动,可设为 -inf(或 -1e9)
attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
然后将 attn_logits
送入 softmax。
4、 总结:选择 expand
or repeat
场景 | 推荐操作 | 原因 |
---|---|---|
构造 attention mask(只读) | expand() | 高效、无数据复制 |
构造后需要修改 | repeat() | 每个位置是独立内存 |