PyTorch 中 Tensor 交换维度(transpose、permute、view)详解

在深度学习中,数据的形状(shape)非常重要。
例如,卷积层、全连接层、RNN 都要求输入张量的维度排列符合特定格式。
很多时候,我们需要对张量的维度(dimension)进行 交换、调整或重新排列。
PyTorch 提供了多种方式实现这一点,比如:
tensor.t()tensor.transpose()tensor.permute()tensor.view()/tensor.reshape()
本文将系统讲解这些操作的原理、区别与常见用法。
文章目录
- 一、什么是“交换维度”?
- 二、基础准备
- 三、方法一:`tensor.t()` —— 仅限二维矩阵转置
- 四、方法二:`tensor.transpose(dim0, dim1)` —— 交换两个指定维度
- 五、方法三:`tensor.permute(dims)` —— 任意维度重排
- 六、方法四:`tensor.view()` 与 `tensor.reshape()` —— 改变形状(非严格意义上的交换)
一、什么是“交换维度”?
在 PyTorch 中,张量(Tensor)是一个多维数组。
“交换维度”指的是改变这些维度(axes)的顺序或位置。
比如一个形状为
[batch, channel, height, width] = [32, 3, 224, 224]
的图像张量,如果模型要求输入为 [batch, height, width, channel],
我们就需要进行维度交换。
二、基础准备
让我们先创建一个简单的 3 维张量来演示:
import torchx = torch.randn(2, 3, 4) # shape: [2, 3, 4]
print(x.shape)
输出:
torch.Size([2, 3, 4])
表示一个形状为 (2, 3, 4) 的三维张量。
三、方法一:tensor.t() —— 仅限二维矩阵转置
如果你的张量是 二维矩阵(2D Tensor),可以使用 t():
x = torch.randn(2, 3)
print(x.shape) # torch.Size([2, 3])
print(x.t().shape) # torch.Size([3, 2])
📌 注意:
t() 只能用于二维张量(矩阵),否则会报错。
四、方法二:tensor.transpose(dim0, dim1) —— 交换两个指定维度
transpose /trænˈspoʊz/ 中文翻译为"调换"
transpose() 用于交换两个指定的维度,不改变其他维度的顺序。
示例:
x = torch.randn(2, 3, 4) # shape: [2, 3, 4]
y = x.transpose(1, 2) # 交换第1维和第2维
print(y.shape)
输出:
torch.Size([2, 4, 3])
也就是说:
原: [batch, channel, length]
新: [batch, length, channel]
⚠️ 注意:
transpose()不会拷贝数据,而是返回一个新的视图(view),节省内存。
有关拷贝的问题,可参考深入理解 Python 的 copy() 函数:浅拷贝与深拷贝详解_python .copy()函数-CSDN博客
五、方法三:tensor.permute(dims) —— 任意维度重排
permute /pərˈmjʊt/ 中文翻译为"重新排列"
permute() 是更通用的维度交换方式。
它允许你指定所有维度的新顺序。
✅ 示例:
x = torch.randn(2, 3, 4)
y = x.permute(1, 0, 2)
print(y.shape)
输出:
torch.Size([3, 2, 4])
📘 解释:
- 原始顺序是
[0, 1, 2] - 新顺序
[1, 0, 2]表示:- 第 0 维 → 原第 1 维
- 第 1 维 → 原第 0 维
- 第 2 维 → 保持不变
transpose() vs permute()
| 特性 | transpose() | permute() |
|---|---|---|
| 功能 | 交换两个维度 | 任意排列多个维度 |
| 参数 | 两个整数维度索引 | 维度索引序列 |
| 返回 | 新视图 | 新视图 |
| 使用场景 | 常用于2维或3维简单交换 | 用于复杂维度调整 |
举例:
# transpose 只能交换两个维度
x.transpose(0, 1)# permute 可以任意排列
x.permute(2, 0, 1)
六、方法四:tensor.view() 与 tensor.reshape() —— 改变形状(非严格意义上的交换)
view() 和 reshape() 是用于改变张量形状(Shape)的操作,
并不直接交换维度,但在实际中常与 permute() 连用。
✅ 示例:
x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1).contiguous().view(2, -1)
print(y.shape)
输出:
torch.Size([2, 12])
📘 解释:
- 先用
permute()调整维度顺序; - 再用
view()拉平成一个新的形状。
⚠️ 注意:
permute()后的张量在内存中可能不是连续存储的;- 因此通常需要
.contiguous()之后再调用view()。
