PyTorch 张量核心知识点
文章目录
- PyTorch 张量核心知识点
- 一、张量基础认知
- 1. 张量的定义
- 2. 张量的维度与形状
- 二、张量创建方法
- 1. 直接创建(基于已知数据)
- 2. 特殊值张量
- 3. 随机张量
- 4. 基于已有张量创建(形状匹配)
- 三、张量数据类型
- 1. 常见数据类型
- 2. 数据类型指定与转换
- 四、张量访问与取值
- 1. 索引访问(多维索引)
- 2. 切片访问(范围取值)
- 3. 单个元素提取(`item()`)
- 4. 掩码取值(布尔索引)
- 五、张量形状修改
- 1. 重塑(`reshape`/`view`)
- 2. 维度重排(`permute`/`transpose`)
- 3. 维度压缩与扩展(`squeeze`/`unsqueeze`)
- 4. 维度扩展(`expand`/`expand_as`)
- 六、张量运算
- 1. 基础算术运算
- 2. 广播机制(Broadcast)
- 3. 数学函数
- (1)三角函数
- (2)比较函数
- (3)统计函数
- 4. 矩阵运算
- 1)普通矩阵乘法(2 维)
- (2)批量矩阵乘法(3 维及以上)
- 5. 张量操作(拼接、堆叠、拆分)
- (1)拼接(`concat`)
- (2)堆叠(`stack`)
- (3)拆分(`split`/`chunk`)
- (4)展平(`flatten`)
- 七、其他常用操作
- 1. 克隆(`clone()`)
- 2. 脱离计算图(`detach()`)
- 八、核心重点总结
- 九、广播机制 “三步法” 总结
PyTorch 张量核心知识点
一、张量基础认知
1. 张量的定义
- 张量(Tensor)是 PyTorch 中数据运算的基本单元,本质是多维数组,用于存储和处理高维数据。
- PyTorch 神经网络的输入、权重、输出等均以张量形式存在,所有运算均基于张量进行。
2. 张量的维度与形状
- 维度:张量的 “阶数”,如 0 维(标量)、1 维(向量)、2 维(矩阵)、3 维及以上(高维张量)。
- 形状(shape/size):描述每个维度的元素个数,格式为
(dim1_len, dim2_len, ..., dimN_len)
。- 例:
torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
的形状为(2,2,2)
(2 个 2×2 矩阵)。
- 例:
- 查看形状:
tensor.shape
或tensor.size()
(两者等价)。
二、张量创建方法
1. 直接创建(基于已知数据)
import torch
# 0 维张量(标量)
t_scalar = torch.tensor(5, dtype=torch.float)
# 1 维张量(向量)
t_vec = torch.tensor([1, 2, 3])
# 3 维张量
t_3d = torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
2. 特殊值张量
函数 | 作用 | 示例 |
---|---|---|
torch.zeros(shape) | 创建全 0 张量 | torch.zeros(2, 3) → 2×3 全 0 |
torch.ones(shape) | 创建全 1 张量 | torch.ones(3, 4) → 3×4 全 1 |
torch.empty(shape) | 创建空张量(未初始化,值随机) | torch.empty(2, 2) |
torch.arange(n) | 创建 0 到 n-1 的连续整数张量(1 维) | torch.arange(6) → [0,1,2,3,4,5] |
注意:
empty
仅分配内存不初始化,速度快但值不可控;zeros/ones
会初始化值,更安全。
3. 随机张量
函数 | 作用 | 示例 |
---|---|---|
torch.rand(shape) | 0~1 均匀分布随机数 | torch.rand(2,3) |
torch.randn(shape) | 标准正态分布(均值 0,方差 1) | torch.randn(2,3) |
torch.randint(low, high, size) | [low, high) 整数随机数 | torch.randint(0,5, (2,3)) |
torch.normal(mean, std, size) | 自定义正态分布(均值 mean,标准差 std) | torch.normal(mean=2, std=1, size=(2,3)) |
- 固定随机种子(确保结果可复现):
torch.manual_seed(100)
(种子值可自定义)。
4. 基于已有张量创建(形状匹配)
基于某张量的形状创建新张量,避免重复写形状参数:
t = torch.rand(2, 3) # 原张量形状 (2,3)
t_empty_like = torch.empty_like(t) # 空张量,形状与 t 一致
t_zeros_like = torch.zeros_like(t) # 全 0 张量,形状与 t 一致
t_rand_like = torch.rand_like(t) # 0~1 随机张量,形状与 t 一致
三、张量数据类型
1. 常见数据类型
类型类别 | 具体类型 | 说明 |
---|---|---|
整数型 | torch.int /torch.int32 | 标准 32 位整数 |
torch.int64 /torch.long | 64 位整数(常用于索引) | |
torch.uint8 | 无符号 8 位整数(0~255) | |
浮点型 | torch.float /torch.float32 | 32 位浮点数(默认浮点类型) |
torch.float64 /torch.double | 64 位浮点数(精度更高) | |
布尔型 | torch.bool | 布尔值(True/False) |
2. 数据类型指定与转换
-
创建时指定:指定
dtype
参数t_int = torch.zeros(2,3, dtype=torch.int) t_float = torch.rand(2,3, dtype=torch.float64)
-
创建后转换:
-
方法 1:
tensor.to(dtype)
tensor.to(dtype) t = torch.zeros(2,3) # 默认 float32 t_uint8 = t.to(torch.uint8)
-
方法 2:简写方法(如
double()
/int()
/long()
)t_double = t.double() # 转 float64 t_long = t.long() # 转 int64
-
四、张量访问与取值
1. 索引访问(多维索引)
-
格式:
tensor[dim0_idx, dim1_idx, ..., dimN_idx]
,支持整数索引。t = torch.arange(24).reshape(2,3,4) # 形状 (2,3,4) print(t[0,1,2]) # 取第 0 个 3×4 矩阵的第 1 行第 2 列 → tensor(6)
2. 切片访问(范围取值)
-
格式:
tensor[dim0_slice, dim1_slice, ...]
,支持start:end:step
切片语法。print(t[:, 1, 1:3]) # 所有矩阵的第 1 行、第 1-2 列 → 形状 (2,2)
3. 单个元素提取(item()
)
-
仅适用于单元素张量(如标量或形状为
(1,)
的张量),返回 Python 原生类型。t_single = t[0,1,2] # 单元素张量 print(t_single.item()) # 6(Python 整数)
4. 掩码取值(布尔索引)
-
用布尔张量筛选满足条件的元素,常用于批量修改。
t = torch.randint(0,3, (5,5)) # 5×5 整数张量 mask = t == 0 # 布尔掩码:True 表示元素为 0 的位置 t[mask] = -1 # 将所有为 0 的元素改为 -1
五、张量形状修改
1. 重塑(reshape
/view
)
-
作用:改变张量形状,元素总数不变(各维度长度乘积需等于原总数)。
t = torch.arange(24) # 形状 (24,) t_234 = t.reshape(2,3,4) # 重塑为 (2,3,4) t_view = t.view(4,6) # 视图方式重塑为 (4,6)
-
区别:
reshape
:可能重新分配内存(若原张量内存不连续)。view
:仅创建视图(共享内存,不重新分配),仅适用于内存连续的张量。
-
便捷语法:用
-1
自动计算某维度长度(仅一个-1
有效)t_auto = t.reshape(2, -1, 4) # -1 自动计算为 3 → 形状 (2,3,4)
2. 维度重排(permute
/transpose
)
-
作用:改变维度的顺序,不改变元素值。
transpose(dim1, dim2)
:仅交换两个维度。permute(dim0, dim1, ...)
:任意重排所有维度。
t = torch.rand(3,4,5) # 原形状 (3,4,5) t_trans = t.transpose(1,2) # 交换 1、2 维 → 形状 (3,5,4) t_perm = t.permute(2,0,1) # 重排为 (5,3,4)
-
特殊:二维张量转置可直接用
tensor.T
t_mat = torch.rand(3,4) t_mat_T = t_mat.T # 转置为 (4,3)
3. 维度压缩与扩展(squeeze
/unsqueeze
)
-
squeeze(dim)
:删除长度为 1 的维度(不指定 dim 则删除所有长度为 1 的维度)。t = torch.rand(1,3,1,4) # 形状 (1,3,1,4) t_sq = t.squeeze() # 删除所有长度 1 维度 → (3,4) t_sq0 = t.squeeze(0) # 仅删除第 0 维 → (3,1,4)
-
unsqueeze(dim)
:在指定位置插入长度为 1 的维度。t = torch.rand(3,4) # 形状 (3,4) t_usq0 = t.unsqueeze(0) # 第 0 维插入 → (1,3,4) t_usq2 = t.unsqueeze(2) # 第 2 维插入 → (3,4,1)
4. 维度扩展(expand
/expand_as
)
-
作用:将长度为 1 的维度扩展为指定长度(浅表复制,共享内存,不新增元素)。
t = torch.arange(6).reshape(2,1,3) # 形状 (2,1,3) t_exp = t.expand(2,4,3) # 第 1 维从 1 扩展到 4 → (2,4,3) t_exp_auto = t.expand(-1,4,-1) # -1 表示保留原长度 → 同上
-
expand_as(tensor)
:扩展为目标张量的形状(需满足广播条件)。t_target = torch.rand(2,4,3) t_exp_as = t.expand_as(t_target) # 扩展为 (2,4,3)
六、张量运算
1. 基础算术运算
-
与标量运算:张量的每个元素与标量进行运算(+、-、×、/、**、%)。
t = torch.arange(6).reshape(2,3) # [[0,1,2],[3,4,5]] print(t + 2) # 所有元素加 2 print(t * 3) # 所有元素乘 3 print(t **2) # 所有元素平方
-
同形张量运算:形状完全相同的张量,对应元素逐一运算。
t2 = torch.tensor([[0,0,1],[2,1,0]]) print(t + t2) # 对应元素相加
2. 广播机制(Broadcast)
-
定义:自动扩展形状不同的张量,使它们可进行元素级运算(无需显式扩展)。
-
广播条件:两个张量从最右侧维度开始比较,每个维度满足 “长度相等” 或 “其中一个为 1”。
x = torch.randint(0,3, (2,3,1)) # 形状 (2,3,1) y = torch.randint(0,3, (3,2)) # 形状 (3,2) print(x + y) # 广播后 x 为 (2,3,2),y 为 (1,3,2) → 结果 (2,3,2)
3. 数学函数
(1)三角函数
-
需先将角度转为弧度(
torch.deg2rad()
)。angles = torch.tensor([30, 60]) radians = torch.deg2rad(angles) sin_vals = torch.sin(radians) # 正弦值
(2)比较函数
函数 | 作用 |
---|---|
torch.eq(t1, t2) | 逐元素判断是否相等(返回布尔张量) |
torch.equal(t1, t2) | 判断两个张量完全相同(形状 + 值) |
torch.allclose(t1, t2, atol=ε) | 判断数值近似相等(atol 为允许误差) |
(3)统计函数
函数 | 作用 | 示例 |
---|---|---|
tensor.max(dim) | 沿指定维度求最大值(返回值 + 索引) | t.max(dim=0) → 按列求最大 |
tensor.min(dim) | 沿指定维度求最小值 | t.min(dim=1) → 按行求最小 |
tensor.mean(dim) | 沿指定维度求均值 | t.mean(dim=0) → 按列求均值 |
tensor.var(dim) | 沿指定维度求方差 | t.var(dim=1) → 按行求方差 |
tensor.std(dim) | 沿指定维度求标准差(方差开根号) | t.std(dim=0) → 按列求标准差 |
-
限制值范围:
t = torch.randn(5,5) # 正态分布随机数 t_clamp = torch.clamp(t, -0.1, 0.1) # 限制在 [-0.1, 0.1]
4. 矩阵运算
-
1)普通矩阵乘法(2 维)
- 要求:前一个张量的最后一维长度 = 后一个张量的倒数第二维长度。
- 函数:
torch.matmul(t1, t2)
或简写t1 @ t2
、torch.mm(t1, t2)
。
A = torch.randint(1,4, (2,3)) # 2×3 矩阵 B = torch.randint(1,4, (3,4)) # 3×4 矩阵 C = A @ B # 结果为 2×4 矩阵
-
区别:
matmul
支持广播,mm
仅支持 2 维张量且不广播。
(2)批量矩阵乘法(3 维及以上)
-
作用:对批量的矩阵逐一相乘(前 N-2 维为批量维度,最后 2 维为矩阵维度)。
-
函数:
torch.bmm(t1, t2)
。A = torch.randint(1,4, (5,2,3)) # 5 个 2×3 矩阵 B = torch.randint(1,4, (5,3,4)) # 5 个 3×4 矩阵 C = torch.bmm(A, B) # 结果为 5 个 2×4 矩阵 → 形状 (5,2,4)
5. 张量操作(拼接、堆叠、拆分)
(1)拼接(concat
)
-
作用:将多个张量沿已有维度拼接(不新增维度)。
-
要求:除拼接维度外,其他维度形状必须一致。
A = torch.rand(2,3,4) B = torch.rand(2,3,4) C_dim0 = torch.concat([A,B], dim=0) # 沿第 0 维拼接 → (4,3,4) C_dim1 = torch.concat([A,B], dim=1) # 沿第 1 维拼接 → (2,6,4)
(2)堆叠(stack
)
-
作用:将多个张量沿新增维度堆叠(会新增维度)。
-
要求:所有张量形状必须完全一致。
A = torch.rand(3,4) B = torch.rand(3,4) C_dim0 = torch.stack([A,B], dim=0) # 新增第 0 维 → (2,3,4) C_dim2 = torch.stack([A,B], dim=2) # 新增第 2 维 → (3,4,2)
(3)拆分(split
/chunk
)
-
split(segment_len, dim)
:按 “每段长度” 拆分。 -
chunk(num_chunks, dim)
:按 “拆分块数” 拆分。t = torch.rand(3,6,5) # 形状 (3,6,5) # split:每段长度 2,沿第 1 维拆分 A1,B1,C1 = t.split(2, dim=1) # 每段形状 (3,2,5) # chunk:拆分为 3 块,沿第 1 维拆分 A2,B2,C2 = t.chunk(3, dim=1) # 每块形状 (3,2,5)
(4)展平(flatten
)
-
作用:将指定维度范围合并为一个维度。
-
格式:
torch.flatten(tensor, start_dim, end_dim)
(默认start_dim=0
,end_dim=-1
)。t = torch.rand(2,3,4,5) t_flat1 = t.flatten(start_dim=1) # 第 1-3 维展平 → (2, 60) t_flat2 = t.flatten(1,2) # 第 1-2 维展平 → (2, 12, 5)
七、其他常用操作
1. 克隆(clone()
)
-
作用:创建张量的深拷贝(新张量与原张量值相同,但内存独立)。
t = torch.arange(10) t_clone = t.clone() t[0] = 100 # 修改原张量,克隆张量不受影响 print(t_clone) # 仍为 [0,1,2,...,9]
2. 脱离计算图(detach()
)
-
作用:创建张量的浅拷贝,脱离当前计算图(仅用于推理,不参与梯度计算)。
t = torch.arange(10, requires_grad=True) t_detach = t.detach() # 脱离计算图,无梯度
八、核心重点总结
- 形状匹配:所有张量运算需确保形状兼容(广播机制可简化部分场景)。
- 维度操作:
reshape
(重塑)、permute
(重排)、squeeze/unsqueeze
(增减维度)是高频操作。 - 矩阵乘法:
matmul
(支持广播)、bmm
(批量矩阵)需注意维度匹配。 - 内存效率:
view
、expand
共享内存,reshape
、clone
可能重新分配内存,按需选择。
九、广播机制 “三步法” 总结
遇到任何广播场景,都可以按以下步骤判断:
- 补维度:给维度数少的张量左侧补 1,直到两个张量维度数一致;
- 比维度:从最右侧维度开始,逐维对比,每个维度需满足 “相等” 或 “其中一个为 1”;
- 扩维度:将所有 “长度为 1” 的维度,扩展为另一个张量对应维度的长度,最终两个张量形状完全一致。
通过以上例子,能覆盖 90% 以上的广播场景,核心是 “右对齐、补 1 维、判规则、扩长度”,多练两次就能快速判断~