Pytorch基础操作
面试的时候,PhD看我简历上面写了”熟悉pytorch框架“,然后就猛猛提问了有关于tensor切片的问题…当然是没答上来,因此在这里整理一下pytorch的一些基础编程语法,常看常新
PyTorch基础操作全解
一、张量初始化
PyTorch的核心数据结构是torch.Tensor
,初始化方法灵活多样:
1. 基础初始化
import torch# 未初始化张量(内存中可能存在随机值)a = torch.empty(3, 2) # 3x2的未初始化矩阵# 均匀分布随机数 [0,1)b = torch.rand(2, 3) # 2x3随机矩阵# 全零矩阵(显式指定类型)c = torch.zeros(4, 3, dtype=torch.long) # 4x3的长整型零矩阵# 从列表创建d = torch.tensor([5.5, 3]) # 直接数值初始化
2. 基于已有张量的初始化
x = torch.rand(2, 2)# 继承原有张量属性(形状/设备)new_tensor = x.new_ones(3, 3, dtype=torch.double) # 3x3全1矩阵,继承x的设备# 正态分布(继承形状)like_tensor = torch.randn_like(x, dtype=torch.float) # 与x同形的正态分布
二、张量属性与运算
1. 关键属性
print(x.dtype) # 数据类型 torch.float32print(x.device) # 存储设备 cpu/cuda:0print(x.shape) # 等价于x.size()
2. 基本运算(加法/矩阵乘法/张量形状操作)
# 加法(三种等价方式)result1 = a + bresult2 = torch.add(a, b)a.add_(b) # in-place操作(会修改a)# 矩阵乘法mat1 = torch.randn(2, 3)mat2 = torch.randn(3, 2)product = torch.mm(mat1, mat2) # 2x2结果矩阵# 形状操作reshaped = x.view(4) # 展平为1D(必须连续内存)resized = x.reshape(-1) # 自动推断维度(处理非连续内存)
view(-1)自动推断维度
# 输入序列 (batch=2, seq_len=5, features=10)seq_data = torch.randn(2, 5, 10)# 转换为(batch*seq_len, features)reshaped = seq_data.view(-1, 10) # 形状[10, 10]print(reshaped.shape) # torch.Size([10, 10])
3. 类型转换
float_tensor = x.to(torch.float64) # 显式转换类型gpu_tensor = x.cuda() # 转移至GPU
三、高级切片与索引
1. 三维张量切片(面试题解析)
假设有张量 tensor = torch.randn(5, 4, 6)
:
• 第一个维度取第一个元素:tensor[0]
(等价于tensor[0, :, :]
)
• 第二个维度取全部元素::
或 ...
• 第三个维度取奇数索引元素:1::2
(从索引1开始,步长2)
完整解:
result = tensor[0, :, 1::2] # shape变为 (4, 3)
2. 高级索引技巧
# 布尔掩码mask = tensor > 0.5selected = tensor[mask]# 组合索引indices = torch.tensor([0, 2])partial = tensor[:, indices, :]
四、与NumPy的互操作
1. 转换机制
# Tensor -> ndarraynumpy_array = tensor.numpy() # CPU张量直接转换# ndarray -> Tensortorch_tensor = torch.from_numpy(numpy_array)# GPU张量转换cpu_tensor = gpu_tensor.cpu()numpy_from_gpu = cpu_tensor.numpy()
2. 内存共享特性:底层其实共享一套内存
a = torch.ones(3)b = a.numpy()a.add_(1) # 修改张量print(b) # [2., 2., 2.] 同步变化
五、扩展知识
一、自动求导机制
- 核心概念
PyTorch使用动态计算图实现自动微分:
• requires_grad:标记需要跟踪梯度的张量
• 计算图:记录张量间的运算关系(正向传播)
• backward():反向传播计算梯度
• grad属性:存储梯度值(默认会累积)
- 基础示例
x = torch.tensor(2., requires_grad=True)y = x**2 + 3*x # 计算图建立y.backward() # 反向传播print(x.grad) # 输出:tensor(7.) # 导数计算:dy/dx = 2x + 3 → 2*2 + 3 = 7
- 梯度累积特性
第二次反向传播前必须清除梯度
# 第二次反向传播前必须清除梯度x.grad.zero_() # 梯度清零y = x**3y.backward()print(x.grad) # 3x² → 3*(2)^2 = 12
- 非标量梯度处理
# 多输出系统需要指定gradient参数x = torch.randn(3, requires_grad=True)y = x * 2v = torch.tensor([0.1, 1.0, 0.001], dtype=torch.float)y.backward(v) # 加权反向传播print(x.grad) # 输出:tensor([0.2000, 2.0000, 0.0020])
- 梯度控制上下文
# 禁用梯度计算(节约内存)with torch.no_grad():inference = x * 2 # 不会记录计算图print(inference.requires_grad) # False# 临时分离张量detached_x = x.detach() # 创建无需梯度的副本
二、张量拼接操作
- 维度拼接 (torch.cat)
必须保证维度匹配,dim=0(第一个维度拼接),dim=1(第二个维度拼接)
# 在现有维度上拼接a = torch.randn(2, 3)b = torch.randn(4, 3)concat_0 = torch.cat([a, b], dim=0) # 形状(6,3)concat_1 = torch.cat([a, a], dim=1) # 形状(2,6)
- 新增维度拼接 (torch.stack)
新增维度拼接使用torch.stack,新增第零个维度维度拼接torch.stack([c, d], dim=0),新增最后一个维度torch.stack([c.T, d.T], dim=2)
# 创建新维度c = torch.randn(3, 4)d = torch.randn(3, 4)stack_0 = torch.stack([c, d], dim=0) # 形状(2,3,4)stack_2 = torch.stack([c.T, d.T], dim=2) # 形状(3,4,2)
- 拼接规则验证
try:# 维度不匹配报错invalid = torch.cat([a, b], dim=1) except RuntimeError as e:print(f"Error: {e}") # 非拼接维度尺寸不一致
三、广播机制详解——两个张量维度不同时,自动对齐维度(复制出来直接补充)
- 广播规则
当两个张量维度不同时:
-
从右向左对齐维度
-
维度相容条件:
• 维度大小相等
• 其中一个维度为1
-
自动扩展:将尺寸为1的维度复制到匹配对方
-
典型示例
# 案例1:向量+标量a = torch.tensor([1, 2, 3])b = torch.tensor(5)print(a + b) # [6,7,8]# 案例2:矩阵+向量matrix = torch.ones(2, 3) # (2,3)vector = torch.arange(3) # (3,) → (1,3) → (2,3)print(matrix + vector)# [[0,1,2],# [0,1,2]] + [[1,1,1],# [1,1,1]] = [[1,2,3],# [1,2,3]]# 案例3:三维广播tensor_3d = torch.ones(4, 3, 2)tensor_2d = torch.tensor([[0], [1], [2]]) # (3,1) → (4,3,2)result = tensor_3d + tensor_2dprint(result.shape) # (4,3,2)
- 广播失败案例
try:a = torch.ones(3, 4)b = torch.ones(2, 5)c = a + bexcept RuntimeError as e:print(f"Error: {e}") # 无法广播
四、综合应用示例
- 梯度控制与广播的结合
with torch.no_grad():base = torch.ones(2, 2)delta = torch.tensor([1., 2.]) # 广播为(2,2)modified = base * deltaprint(modified) # 无梯度跟踪# 输出:# tensor([[1., 2.],# [1., 2.]])
- 拼接与自动求导
x = torch.tensor([1., 2.], requires_grad=True)y = torch.cat([x, x**2], dim=0) # 拼接成[1,2,1,4]loss = y.sum() # 1+2+1+4 = 8loss.backward()print(x.grad) # [1+2x, 2+0] → [1+2*1=3, 2+0=2]# 输出:tensor([3., 2.])