PyTorch 中unsqueeze(-1)用法
unsqueeze(-1)
是 PyTorch 中的一个张量操作,用于在指定维度上增加一个长度为1的维度(即扩展维度)。具体解析如下:
功能说明
-
作用位置
-1
表示在张量的最后一个维度后添加新维度。
(等价于dim=len(tensor.shape)
) -
输入输出对比
- 假设原张量
train_X
形状为(N,)
(一维向量) - 执行后形状变为
(N, 1)
(二维矩阵)
- 假设原张量
-
典型用途
- 适配神经网络层输入要求(如全连接层需要二维输入)
- 广播机制(Broadcasting)前的维度对齐
- 处理单通道数据(如时间序列、灰度图像)
示例演示
import torch# 原始数据(一维张量)
data = torch.tensor([1, 2, 3]) # shape: (3,)# 添加维度后
expanded = data.unsqueeze(-1) # shape: (3, 1)
print(expanded)
输出:
tensor([[1],[2],[3]])
其他等价写法
unsqueeze(1)
:当输入为一维时效果与unsqueeze(-1)
相同data[:, None]
:Python 切片语法实现相同功能