深度学习之pytorch基本使用(二)
五、张量索引与形状操作
1.张量索引操作
以data = torch.randint(0,10,[4,5])
(示例输出:tensor([[0,7,6,5,9],[6,8,3,1,0],[6,3,8,7,3],[4,9,5,3,1]])
)为例,常见索引方式如下:
(1)简单行列索引:data[行索引]
取指定行,data[:,列索引]
取指定列。示例:
data = torch.randint(0,10,[4,5])print(data[0]) # 第0行 → tensor([0,7,6,5,9])
print(data[:,0]) # 第0列 → tensor([0,6,6,4])
(2)列表索引:data[[行索引列表], [列索引列表]]
取指定位置元素;data[[[行索引1],[行索引2]], [列索引1,列索引2]]
取指定行的指定列元素。示例:
print(data[[0,1],[1,2]]) # (0,1)、(1,2)元素 → tensor([7,3])
print(data[[[0],[1]],[1,2]]) # 0、1行的1、2列 → tensor([[7,6],[8,3]])
(3)范围索引:data[行范围, 列范围]
,:
表示所有元素,start:end
表示[start, end)
区间。示例:
print(data[:3, :2]) # 前3行前2列 → tensor([[0,7],[6,8],[6,3]])
print(data[2:, :2]) # 第2行到最后行的前2列 → tensor([[6,3],[4,9]])
(4)布尔索引
:data[布尔条件]
,筛选满足条件的行或列。示例:
print(data[data[:,2] > 5]) # 第3列大于5的行 → tensor([[0,7,6,5,9],[6,3,8,7,3]])
print(data[:, data[1] > 5]) # 第2行大于5的列 → tensor([[0,7],[6,8],[6,3],[4,9]])
(5)多维索引:对高维张量(如 3 维[3,4,5]
),按data[轴0索引, 轴1索引, 轴2索引]
取值。示例:
data = torch.randint(0,10,[3,4,5])
print(data[0, :, :]) # 轴0第0个元素 → (4,5)形状张量
print(data[:, 0, :]) # 轴1第0个元素 → (3,5)形状张量
print(data[:, :, 0]) # 轴2第0个元素 → (3,4)形状张量
2.张量形状操作
(1)reshape(shape):在保证元素总数不变的前提下,修改张量形状,不改变原张量。示例:
data = torch.tensor([[10,20,30],[40,50,60]]) # (2,3)
new_data = data.reshape(1,6) # (1,6) → tensor([[10,20,30,40,50,60]])
(2)squeeze()/unsqueeze(dim):
squeeze():删除所有形状为 1 的维度(降维),若指定dim
,仅删除该维度(需为 1)。
unsqueeze():在指定dim
处添加形状为 1 的维度(升维),dim=-1
表示最后一个维度。示例:
data = torch.tensor([1,2,3,4,5]) # (5)
data1 = data.unsqueeze(0) # (1,5) → tensor([[1,2,3,4,5]])
data2 = data.unsqueeze(-1) # (5,1) → tensor([[1],[2],[3],[4],[5]])
data3 = data2.squeeze() # (5) → 恢复原形状
(3)transpose(dim1,dim2)/permute(dim_list):
transpose(dim1,dim2):交换两个指定维度的形状,仅支持交换两个维度。
permute(dim_list):一次交换多个维度,按dim_list顺序重新排列维度。示例:
data = torch.randint(0,10,[3,4,5]) # (3,4,5)
data1 = torch.transpose(data,1,2) # 交换1和2维 → (3,5,4)
data2 = data.permute([1,2,0]) # 按[1,2,0]重排 → (4,5,3)
(4)view(shape)/contiguous():
view(shape):与reshape类似,但仅支持内存连续的张量(张量是否连续可通过data.is_contiguous()判断)。
contiguous():将非连续内存的张量转换为连续内存,以便使用view。示例:
data = torch.tensor([[10,20,30],[40,50,60]]) # 连续内存
data1 = torch.transpose(data,0,1) # 非连续内存 → is_contiguous()=False
# data1.view(2,3) # 报错,非连续内存
data2 = data1.contiguous() # 转为连续内存
data3 = data2.view(2,3) # 正常运行 → (2,3)形状
六、张量拼接与自动微分模块
1.张量拼接(torch.cat())
功能:将多个张量按指定维度(dim
)拼接,不改变维度数,要求除拼接维度外,其他维度形状一致。示例:
data1 = torch.randint(0,10,[1,2,3]) # (1,2,3)
data2 = torch.randint(0,10,[1,2,3]) # (1,2,3)# 按dim=0拼接 → (2,2,3)
new_data1 = torch.cat([data1,data2], dim=0)
# 按dim=1拼接 → (1,4,3)
new_data2 = torch.cat([data1,data2], dim=1)
# 按dim=2拼接 → (1,2,6)
new_data3 = torch.cat([data1,data2], dim=2)
2.自动微分模块(torch.autograd)
核心作用:实现反向传播,计算损失函数对模型参数的梯度,用于参数更新。
关键步骤:
(1)定义需计算梯度的张量:创建时设置requires_grad=True
(默认False
)。
(2)构建计算图:定义模型输出(如z = x*w + b
)和损失函数(如 MSE)。
(3)反向传播:调用loss.backward()
自动计算梯度,梯度存储在张量的grad
属性中。
示例:
def demo03_calc_grad():x = torch.ones(2,5)y = torch.zeros(2,3)#y(2,3) = x*w(2,5 * 5,3 = 2,3) +b(,3)w = torch.randn(5,3,requires_grad=True)b = torch.randn(3, requires_grad=True)z = torch.matmul(x,w) + bloss = torch.nn.MSELoss()loss = loss(z,y)loss.backward()print(w.grad)print(b.grad)if __name__ == "__main__":demo03_calc_grad()