当前位置: 首页 > news >正文

Tensor :核心概念、常用函数与避坑指南

一、Tensor 核心概念:从定义到本质

Tensor 是 PyTorch 中最基础的数据结构,可理解为“多维数组”,是深度学习中存储数据(输入/权重/梯度)、执行计算的核心载体。其本质与 NumPy 数组类似,但支持 GPU 加速和自动求导,这是区别于普通数组的关键。

1. 1. Tensor 的维度与形状

  • 维度(dim):Tensor 的“阶数”,1 维对应向量、2 维对应矩阵、3 维及以上对应高维张量(如 [batch, channel, height, width] 是 4 维图像张量)。
  • 形状(shape):用元组表示各维度的元素数量,是 Tensor 最核心的属性之一(决定数据的组织形式)。
    示例:
    import torch
    # 1维(shape=(3,))
    t1 = torch.tensor([1,2,3])
    # 2维(shape=(2,3):2行3列)
    t2 = torch.tensor([[1,2,3],[4,5,6]])
    # 3维(shape=(2,2,3):2个2行3列的矩阵)
    t3 = torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])print(t2.shape)  # 输出: torch.Size([2, 3])
    print(t3.dim())  # 输出: 3
    

1. 2. Tensor 的数据类型(dtype)

PyTorch 支持多种数据类型,必须显式指定或匹配(不同 dtype 无法直接计算),常用类型如下:

dtype 类型用途说明注意事项
torch.float32默认浮点类型(占4字节)深度学习中最常用(平衡精度与速度)
torch.float64高精度浮点(占8字节)需手动指定(如 dtype=torch.float64
torch.int32/int64整数类型(分别占4/8字节)标签/索引常用 int64(默认整数类型)
torch.bool布尔类型(True/False)用于掩码(mask)操作

易错点:不同 dtype 张量运算会报错,需用 .to(dtype) 统一类型:

a = torch.tensor([1,2], dtype=torch.float32)
b = torch.tensor([3,4], dtype=torch.int64)
# a + b  # 直接运算报错
a + b.to(torch.float32)  # 正确:将b转为float32

1. 3. Tensor 的设备(device):CPU 与 GPU

Tensor 可存储在 CPU 或 GPU(需支持 CUDA)上,只有同设备的 Tensor 才能运算,这是新手最易踩的坑之一。

  • 查看设备:tensor.device
  • 切换设备:tensor.to(device)device=torch.device("cpu")"cuda:0"

示例:

# 1. 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # 输出: cuda(若有GPU)或 cpu# 2. 创建时指定设备
t = torch.tensor([1,2,3], device=device)# 3. 已创建的Tensor切换设备(返回新Tensor,原Tensor不变)
t_cpu = t.to("cpu")
print(t.device)    # 输出: cuda:0(原Tensor仍在GPU)
print(t_cpu.device)# 输出: cpu

易错点

  • 误以为 tensor.to(device) 会“原地修改”原 Tensor,实际会返回新 Tensor,需重新赋值(如 t = t.to(device));
  • GPU 张量无法直接用 printnumpy() 查看(需先转 CPU:t.cpu().numpy())。

1. 4. 内存布局:shape 与 stride(步幅)

Tensor 无论多少维,在内存中均以一维连续存储(默认“行优先”,即先存完一行再存下一行,对应 C 语言风格)。stride(步幅)是连接“多维索引”与“一维内存”的关键。

  • stride 定义:元组,stride[k] 表示“在第 k 维移动一个元素时,需在内存中跳过的元素个数”。
  • 计算规律(行优先)stride[k] = 第k+1维大小 × 第k+2维大小 × ... × 最后一维大小(即当前维度之后所有维度的乘积)。

示例(对应前文案例):

x = torch.tensor([[1,2,3],[4,5,6]])  # shape=(2,3)
print(x.stride())  # 输出: (3,1) → dim0(行)步幅=3,dim1(列)步幅=1
# 解释:跨行(dim0)需跳过1整行(3个元素),跨列(dim1)直接取下一个元素y = torch.tensor([[1,2],[3,4],[5,6]])  # shape=(3,2)
print(y.stride())  # 输出: (2,1) → dim0步幅=2(1行2个元素)

核心作用:通过 shapestride 定位元素。例如 x[1,2](第1行第2列,值为6)的内存索引计算:
内存索引 = 1×stride[0] + 2×stride[1] = 1×3 + 2×1 = 5 → 对应内存中第5个元素(0开始计数),即6。

1. 5. 自动求导开关:requires_grad

requires_grad 是 Tensor 支持自动求导(反向传播)的核心属性,默认值为 False(普通数据张量)。若需计算该 Tensor 的梯度(如模型权重),需设为 True

  • 查看/设置:tensor.requires_grad(查看)、tensor.requires_grad_(True)(原地设置,下划线表示“原地操作”)。
  • 梯度存储:当执行 loss.backward() 后,梯度会存在 tensor.grad 中(仅 requires_grad=True 的 Tensor 有 grad 属性)。

示例:

# 1. 创建可求导的Tensor
x = torch.tensor([2.0], requires_grad=True)  # 注意:求导需用浮点类型(int无法求导)
y = x ** 2  # 计算图:y = x²# 2. 反向传播(求y对x的梯度)
y.backward()  # 等价于 dy/dx = 2x# 3. 查看梯度
print(x.grad)  # 输出: tensor([4.]) → 当x=2时,dy/dx=4,正确

易错点

  • 整数类型(如 int32)无法求导,需先转为浮点(float32);
  • 若只需“冻结部分参数”(如迁移学习),需将不需要更新的 Tensor 设为 requires_grad=False,并在优化器中过滤(filter(lambda p: p.requires_grad, model.parameters()))。

二、Tensor 常用操作:创建、变形、索引与计算

按“数据处理流程”分类,覆盖 90% 实战场景。

2. 1. Tensor 创建函数

避免手动写 torch.tensor([...]),用以下函数高效创建批量/特殊张量:

函数功能说明示例
torch.empty(shape)创建未初始化的张量(值随机,速度最快)torch.empty(2,3) → shape=(2,3)
torch.zeros(shape)创建全0张量torch.zeros(3,1) → 3行1列全0
torch.ones(shape)创建全1张量torch.ones(2,2,2) → 3维全1
torch.arange(start,end,step)生成连续整数(左闭右开)torch.arange(0,10,2) → [0,2,4,6,8]
torch.linspace(start,end,n)生成n个均匀分布的数(左闭右闭)torch.linspace(0,1,5) → [0,0.25,0.5,0.75,1]
torch.randn(shape)生成标准正态分布(μ=0,σ=1)张量torch.randn(2,3) → 2行3列随机数
torch.rand(shape)生成[0,1)均匀分布张量torch.rand(1,5) → 1行5列[0,1)数
torch.from_numpy(np_arr)从NumPy数组转为Tensor(共享内存)np_arr = np.array([1,2])torch.from_numpy(np_arr)

关键区别

  • torch.tensor():从已有数据创建(复制数据);
  • torch.from_numpy():从 NumPy 数组创建(共享内存,修改一个会影响另一个)。

2. 2. 形状变形:view() 与 reshape()(核心高频)

变形操作不改变 Tensor 的元素数量和值,只改变维度组织形式,需满足“原形状元素总数 = 新形状元素总数”(否则报错)。

函数功能说明适用场景注意事项
tensor.view(new_shape)基于原 Tensor 的 stride 变形(不复制内存)原 Tensor 是“连续内存”(tensor.is_contiguous() 为 True)若原 Tensor 不连续(如转置后),会报错,需先调用 tensor.contiguous()
tensor.reshape(new_shape)智能变形(优先不复制,不连续则自动复制)所有场景(推荐新手用,兼容性更强)本质是“contiguous() + view()”的封装,无需手动处理连续性

示例:

t = torch.arange(0,6)  # shape=(6,),元素总数=6
print(t.view(2,3))     # 变形为(2,3) → 正确(2×3=6)
print(t.reshape(3,2))  # 变形为(3,2) → 正确(3×2=6)# 转置后(非连续)的变形
t_t = t.view(2,3).t()  # 转置为(3,2),此时 t_t.is_contiguous() → False
# t_t.view(6,)  # 报错(非连续)
t_t.reshape(6,)        # 正确(自动处理连续性)

拓展变形函数

  • tensor.unsqueeze(dim):在指定维度插入一个“1维”(如 (2,3)unsqueeze(1)(2,1,3));
  • tensor.squeeze(dim):删除指定维度的“1维”(若该维度大小为1,如 (2,1,3)squeeze(1)(2,3));
  • tensor.transpose(dim0, dim1):交换两个维度(如 (2,3,4)transpose(1,2)(2,4,3));
  • tensor.permute(dims):重排所有维度(如 (batch, channel, H, W)permute(0,2,3,1)(batch, H, W, channel),适配图像显示)。

2. 3. 索引与切片:提取部分元素

与 Python 列表/NumPy 索引逻辑一致,但支持“多维同时索引”,是数据筛选的核心操作。

基础索引(单元素/范围)
t = torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])  # shape=(2,2,3)# 1. 单元素索引(3维:dim0, dim1, dim2)
print(t[0,1,2])  # 输出: tensor(6) → 第0个矩阵、第1行、第2列# 2. 范围切片(用 : 表示“所有元素”,a:b 表示“从a到b-1”)
print(t[0, :, 1:3])  # 输出: [[2,3],[5,6]] → 第0个矩阵、所有行、第1-2列
print(t[:, 1, :])    # 输出: [[4,5,6],[10,11,12]] → 所有矩阵、第1行、所有列
高级索引(掩码/整数索引)
  • 掩码索引(masked_select):用布尔张量筛选“True 位置”的元素(返回1维张量);
  • 整数索引(index_select):用整数张量筛选“指定索引”的元素(需指定维度)。

示例:

t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])  # shape=(3,3)# 1. 掩码索引:筛选大于5的元素
mask = t > 5  # 布尔张量:[[False,False,False],[False,False,True],[True,True,True]]
print(t.masked_select(mask))  # 输出: tensor([6,7,8,9]) → 1维# 2. 整数索引:在dim0(行)上筛选第0、2行
idx = torch.tensor([0,2])
print(t.index_select(0, idx))  # 输出: [[1,2,3],[7,8,9]] → shape=(2,3)

易错点:高级索引(masked_select/index_select)会复制内存(返回新张量,与原张量无关联),而基础切片(:)是“视图”(不复制内存,修改切片会影响原张量)。

2. 4. 广播机制(Broadcasting):不同形状的张量运算

当两个 Tensor 形状不完全一致时,PyTorch 会自动触发“广播”,将它们调整为相同形状后再运算(避免手动扩展维度,简化代码)。

广播的 2 个规则(必须同时满足)
  1. 维度兼容:从“最后一个维度”开始对比,每个维度的大小要么相等,要么有一个为1
  2. 维度补全:若维度数不同,在“形状较短的 Tensor”前面补1,直到维度数一致。
示例(符合广播的场景)
# 场景1:维度数相同,部分维度为1
a = torch.ones(2,3)    # shape=(2,3)
b = torch.tensor([[1],[2]])  # shape=(2,1)
# 广播后:a保持(2,3),b扩展为(2,3)(每列重复[1]或[2])
print(a + b)  # 输出: [[2,2,2],[3,3,3]]# 场景2:维度数不同,补1后兼容
c = torch.ones(3)      # shape=(3,) → 补1后为(1,3)
d = torch.tensor([[1],[2]])  # shape=(2,1)
# 广播后:c扩展为(2,3),d扩展为(2,3)
print(c + d)  # 输出: [[2,2,2],[3,3,3]]
不符合广播的场景(报错)
e = torch.ones(2,3)  # shape=(2,3)
f = torch.ones(2,4)  # shape=(2,4)
# e + f  # 报错:最后一个维度3≠4,不满足规则1

避坑建议:不确定是否能广播时,用 torch.broadcast_tensors(a,b) 验证(返回广播后的张量,无报错则兼容)。

2. 5. 常用计算函数(按功能分类)

1. 逐元素运算(Element-wise)

对 Tensor 每个元素单独运算,输入输出形状相同:

  • 算术:torch.add(a,b)(a+b)、torch.sub(a,b)(a-b)、torch.mul(a,b)(元素乘)、torch.div(a,b)(元素除)
  • 激活函数:torch.relu(tensor)torch.sigmoid(tensor)torch.tanh(tensor)(均为逐元素计算)

示例:

a = torch.tensor([[1,2],[3,4]])
b = torch.tensor([[5,6],[7,8]])
print(torch.mul(a,b))  # 元素乘:[[5,12],[21,32]]
print(torch.relu(torch.tensor([[-1,2],[-3,4]])))  # relu:[[0,2],[0,4]]
2. 归约运算(Reduction)

对指定维度“压缩求和/求均值”等,会改变 Tensor 形状(需指定 dim,否则对所有元素计算):

函数功能关键参数说明
torch.sum(tensor, dim)按维度求和dim=0(按列求和),dim=1(按行求和);keepdim=True(保留原维度数)
torch.mean(tensor, dim)按维度求均值仅支持浮点型 Tensor(int 需先转 float)
torch.max(tensor, dim)按维度求最大值(返回值+索引)返回元组 (max_value, max_index)

示例:

t = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
# 按行求和(dim=1),保留原维度(shape从(2,3)→(2,1))
print(torch.sum(t, dim=1, keepdim=True))  # 输出: [[6.],[15.]]
# 按列求最大值,返回值和索引
max_val, max_idx = torch.max(t, dim=0)
print(max_val)  # 输出: tensor([4.,5.,6.])
print(max_idx)  # 输出: tensor([1,1,1])

易错点mean 函数无法直接用于整数 Tensor,需先通过 .float() 转换类型。

3. 矩阵运算(Matrix Operation)

区别于“元素乘”,需满足矩阵乘法规则(前一个 Tensor 的列数 = 后一个 Tensor 的行数):

  • torch.mm(a, b):仅支持 2 维 Tensor 矩阵乘法(a 是 (m,n),b 是 (n,p),输出 (m,p));
  • torch.matmul(a, b):支持高维 Tensor 矩阵乘法(自动忽略前导维度,仅对最后两个维度做矩阵乘);
  • torch.bmm(a, b):批量矩阵乘法(a 是 (batch, m,n),b 是 (batch, n,p),输出 (batch, m,p))。

示例:

# 2维矩阵乘法(mm)
a = torch.tensor([[1,2],[3,4]])  # (2,2)
b = torch.tensor([[5],[6]])      # (2,1)
print(torch.mm(a, b))  # 输出: [[17],[39]](1*5+2*6=17;3*5+4*6=39)# 批量矩阵乘法(bmm)
batch_a = torch.randn(3, 2, 4)  # 3个(2,4)矩阵
batch_b = torch.randn(3, 4, 5)  # 3个(4,5)矩阵
print(torch.bmm(batch_a, batch_b).shape)  # 输出: torch.Size([3,2,5])

三、Tensor 核心避坑指南(新手必看)

3. 1. 原地操作(In-place Operation)的“隐形坑”

  • 定义:以下划线 _ 结尾的函数(如 tensor.add_()tensor.requires_grad_()),会直接修改原 Tensor 的值/属性,不返回新 Tensor。
  • 风险场景
    1. 若 Tensor 参与计算图(requires_grad=True),原地操作可能破坏计算图,导致反向传播报错;
    2. 误将原地操作结果赋值(如 t = t.add_(1)),实际 add_ 已修改原 t,赋值无意义且易混淆。

正确做法

  • 普通数据处理:可用原地操作(如 t.zero_() 清空张量);
  • 模型训练(涉及求导):优先用非原地操作(如 t = t.add(1))。

3. 2. “视图(View)”与“副本(Copy)”的区别

Tensor 操作分为“不复制内存”和“复制内存”两类,混淆会导致数据修改异常:

类型操作示例特点(是否复制内存)修改影响
视图(View)基础切片(t[:,1])、view()reshape()(原张量连续时)不复制,共享内存修改视图会同步修改原张量
副本(Copy)tensor.clone()masked_select()index_select()reshape()(原张量不连续时)复制,独立内存修改副本不影响原张量

示例(视图的共享内存问题):

t = torch.tensor([[1,2],[3,4]])
view_t = t[:, 0]  # 视图(第0列:[1,3])
view_t[0] = 100   # 修改视图
print(t)  # 输出: [[100,2],[3,4]] → 原张量被同步修改!

避坑建议:若需“独立修改”,主动用 tensor.clone() 创建副本(如 copy_t = t[:,0].clone())。

3. 3. 求导相关的 3 个关键错误

  1. 整数 Tensor 无法求导
    原因:梯度计算基于浮点运算,整数类型无“小数梯度”。
    解决:创建时指定浮点 dtype(如 torch.tensor([2.0], requires_grad=True))。

  2. detach() 后的 Tensor 无法求导
    场景:用 tensor.detach() 分离计算图(常用于“冻结特征”),分离后的 Tensor requires_grad=False
    错误:对 detach() 后的 Tensor 执行 backward()
    解决:若需重新求导,用 detach_tensor.requires_grad_(True) 手动开启(但会脱离原计算图)。

  3. 多次 backward() 需清空梯度
    场景:模型训练中,若未清空前一次的梯度(optimizer.zero_grad()),多次 backward() 会导致梯度累加。
    后果:梯度异常,模型训练发散。
    解决:每次反向传播前,调用 optimizer.zero_grad()tensor.grad.zero_()

3. 4. 设备不匹配的“高频报错”

错误信息:RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

  • 原因:参与运算的 Tensor 分别在 CPU 和 GPU 上。
  • 解决步骤:
    1. print(tensor.device) 检查所有 Tensor 的设备;
    2. 统一设备(推荐将所有 Tensor 移到 GPU,如 t = t.to(device)device 提前定义为 cudacpu);
    3. 注意:模型参数(model.parameters())也需移到对应设备(model = model.to(device))。

四、实战总结:Tensor 操作流程图

按“数据处理→模型训练→结果输出”的流程,梳理核心操作链路:

  1. 数据加载:用 torch.from_numpy()/torch.tensor() 导入数据 → 转浮点 dtype(t = t.float());
  2. 设备迁移:定义 device → 数据和模型移到设备(t = t.to(device)model = model.to(device));
  3. 前向计算:执行张量变形(reshape()/permute())、算术/矩阵运算 → 得到预测结果;
  4. 反向传播:计算损失 → loss.backward()(自动求梯度) → 优化器更新参数(optimizer.step());
  5. 结果处理:GPU 张量转 CPU(t = t.cpu()) → 转 NumPy(t.numpy()) → 可视化/保存。
http://www.dtcms.com/a/388803.html

相关文章:

  • 机器学习实战·第四章 训练模型(1)
  • 一次因表单默认提交导致的白屏排查记录
  • Linux:io_uring
  • 《第九课——C语言判断:从Java的“文明裁决“到C的“原始决斗“——if/else的生死擂台与switch的轮盘赌局》
  • 学习日报|Spring 全局异常与自定义异常拦截器执行顺序问题及解决
  • Spring Boot 参数处理
  • Debian系统基本介绍:新手入门指南
  • Spring Security 框架
  • Qt QPercentBarSeries详解
  • RTT操作系统(3)
  • DNS服务管理
  • IDA Pro配置与笔记
  • 虚函数表在单继承与多继承中的实现机制
  • 矿石生成(1)
  • Linux 线程的概念
  • Unity学习之资源管理(Resources、AssetDatabase、AssetBundle、Addressable)
  • LG P5138 fibonacci Solution
  • 删除UCPD监控服务或者监控驱动
  • 日语学习-日语知识点小记-构建基础-JLPT-N3阶段(33):文法運用第10回1+(考え方14)
  • 向量技术研究报告:从数学基础到AI革命的支柱
  • 802.1x和802.1Q之间关联和作用
  • 基于大模型多模态的人体体型评估:从“尺码测量”到“视觉-感受”范式
  • 更符合人类偏好的具身导航!HALO:面向机器人导航的人类偏好对齐离线奖励学习
  • Transformer多头注意力机制
  • git 分支 error: src refspec sit does not match any`
  • VN1640 CH5 I/O通道终极指南:【VN1630 I/O功能在电源电压时间精确度测试中的深度应用】
  • qt QHorizontalBarSeries详解
  • 半导体制造的芯片可靠性测试的全类别
  • MySQL 索引详解:原理、类型与优化实践
  • AI 重塑就业市场:哪些岗位将被替代?又会催生哪些新职业赛道?