pytorch chunk 切块
目录
chunk切块
chunk切块
import torch# 创建一个形状为 [2, 3, 4] 的张量
x = torch.arange(6).reshape(2, 3)
print("原始张量形状:", x.shape)
print("x:", x)
# 输出: 原始张量形状: torch.Size([2, 3, 4])# 沿着最后一个维度分割成 2 块
chunks = x.chunk(2, dim=-1)
print("分割后的块数量:", len(chunks))
# 输出: 分割后的块数量: 2# 查看每个块的形状
for i, chunk in enumerate(chunks):print(f"块 {i} 的形状:", chunk)
结果:
原始张量形状: torch.Size([2, 3])
x: tensor([[0, 1, 2],
[3, 4, 5]])
分割后的块数量: 2
块 0 的形状: tensor([[0, 1],
[3, 4]])
块 1 的形状: tensor([[2],
[5]])