当前位置: 首页 > news >正文

PyTorch张量运算、索引与自动微分详解

深度学习基础:PyTorch张量运算、索引与自动微分详解

前言

在上一篇文章中,学习了PyTorch张量的创建和基本操作。今天我们将深入学习PyTorch的进阶功能,包括张量的数学运算、各种索引操作、形状变换以及最重要的自动微分机制。这些内容是构建神经网络模型的核心基础,掌握它们对于后续的深度学习学习至关重要。

一、张量数学运算详解

1.1 统计运算函数

PyTorch提供了丰富的统计运算函数,这些函数在深度学习中经常用于数据分析和模型训练:

import torchdef dm01_operation_func():# 设置随机种子torch.manual_seed(2)# 创建测试数据data = torch.randint(0, 10, (2, 3), dtype=torch.float)print("data-->\n", data, data.dtype)# 求均值 mean()data1 = data.mean(dim=0)    # dim=0表示按第0维平均print("按列求平均-->\n", data1)data1 = data.mean(dim=1)    # dim=1表示按第1维平均print("按行求平均-->\n", data1)# 求和 sum()data2 = data.sum()          # 所有元素求和data2 = data.sum(dim=0)     # 按列求和data2 = data.sum(dim=1)     # 按行求和print("求和结果-->", data2)# 平方运算data3 = data.pow(2)data4 = torch.pow(data, 2)  # 两种写法等价print("平方-->", data3)# 开方运算data5 = data.sqrt()print("开方-->", data5)# 指数运算(以e为底)data6 = data.exp()print("指数-->", data6)# 对数运算data7 = data.log()    # 以e为底的对数data8 = data.log2()   # 以2为底的对数data9 = data.log10()  # 以10为底的对数print("对数-->", data7)

重要概念:

  • dim参数:指定沿哪个维度进行计算
  • 对于二维张量:dim=0表示按列计算,dim=1表示按行计算
  • 计算后指定维度会被"压缩"(从结果中移除)

1.2 数学运算的两种写法

PyTorch中很多运算都支持两种写法:

# 方法一:张量对象调用方法
data.pow(2)
data.mean(dim=0)# 方法二:torch模块调用函数
torch.pow(data, 2)
torch.mean(data, dim=0)

二、张量索引操作详解

索引操作是数据处理中的核心技能,PyTorch提供了多种灵活的索引方式:

2.1 简单行列索引

def dm01_simple_index():torch.manual_seed(2)data = torch.randint(0, 10, (4, 5))print("data-->\n", data)# 取第0行(三种等价写法)print("第0行-->", data[0])print("第0行-->", data[0,])print("第0行-->", data[0, :])# 取第0列print("第0列-->", data[:, 0])# 取第3列print("第3列-->", data[:, 3])

2.2 列表索引

def dm02_list_index():torch.manual_seed(2)data = torch.randint(0, 10, (4, 5))print("data-->\n", data)# 列表索引:从指定行和列中取值data2 = data[[[0], [1]], [1, 2]]# 这里是指从0行和1行中取第1列和第2列的数据print("列表索引结果-->", data2)

2.3 范围索引

def dm03_range_index():torch.manual_seed(2)data = torch.randint(0, 10, (4, 5))print("data-->\n", data)# 范围索引data3 = data[:2, :2]    # 前2行,前2列print("前2行前2列-->\n", data3)data3 = data[:2, 2:]    # 前2行,后3列print("前2行后3列-->\n", data3)data3 = data[2:, 2:]    # 后2行,后3列print("后2行后3列-->\n", data3)

2.4 布尔索引

布尔索引是数据筛选的强大工具:

def dm04_bool_index():torch.manual_seed(2)data = torch.randint(0, 10, (4, 5))print("data-->\n", data)# 布尔索引:过滤行# 第2列大于5的所有行data1 = data[data[:, 2] > 5]print("第2列>5的行-->", data1)# 布尔索引:过滤列# 第2行大于5的列data2 = data[:, data[1] > 5]print("第2行>5的列-->\n", data2)

2.5 多维索引

def dm05_multi_dim_index():torch.manual_seed(2)data = torch.randint(0, 10, (3, 4, 5))print("data-->\n", data)# 多维索引data1 = data[0, :, :]    # 取第一个二维矩阵data2 = data[:, 1, :]    # 取所有矩阵第1行data3 = data[:, :, 2]    # 取所有矩阵第2列print("所有矩阵第2列-->\n", data3)

三、张量形状变换详解

3.1 reshape()函数

def dm01_reshape():data = torch.tensor([[1, 2, 3], [4, 5, 6]])print("原始形状-->", data.shape)# reshape函数data1 = data.reshape(1, 6)    # 1行6列print("reshape(1,6)-->", data1, data1.shape)data1 = data.reshape(6)       # 一维张量print("reshape(6)-->", data1, data1.shape)# 使用-1自动推断维度data1 = data.reshape(-1)      # 自动推断为一维print("reshape(-1)-->", data1, data1.shape)

3.2 squeeze()和unsqueeze()函数

这两个函数用于增加或减少维度为1的维度:

def dm02_squeeze_unsqueeze():data = torch.tensor([1, 2, 3, 4, 5])print("原始数据-->", data, data.shape)# 升维:unsqueeze()data1 = data.unsqueeze(dim=0)    # 在第0维增加维度data2 = data.unsqueeze(dim=1)    # 在第1维增加维度data3 = data.unsqueeze(dim=-1)   # 在最后一维增加维度print("unsqueeze(0)-->", data1, data1.shape)print("unsqueeze(1)-->", data2, data2.shape)print("unsqueeze(-1)-->", data3, data3.shape)# 降维:squeeze()data5 = data1.squeeze()          # 移除所有维度为1的维度print("squeeze()-->", data5, data5.shape)

3.3 transpose()和permute()函数

这两个函数用于交换张量的维度:

def dm03_transpose_permute():data = torch.randint(0, 10, (3, 4, 5))print("原始数据-->", data.shape)# transpose:一次只能交换两个维度data1 = data.transpose(dim0=2, dim1=0)print("transpose后-->", data1.shape)# 使用负数索引data1 = data.transpose(dim0=-2, dim1=-1)print("transpose(-2,-1)-->", data1.shape)# permute:可以一次交换多个维度data2 = data.permute(2, 0, 1)print("permute后-->", data2.shape)  # (5, 3, 4)

选择建议:

  • 一次只交换2个维度:优先使用transpose
  • 一次交换多个维度:优先使用permute

3.4 view()和contiguous()函数

def dm04_view_contiguous():data = torch.tensor([[1, 2, 3], [4, 5, 6]])print("原始数据-->", data.shape)# view函数进行维度变换data1 = data.view(6)        # 变为一维data1 = data.view(3, 2)     # 变为3行2列data1 = data.view(3, -1)    # -1表示自动推断print("view后-->", data1.shape)# 先transpose再viewdata2 = data.transpose(0, 1)print("transpose后-->", data2.shape)# 检查内存是否连续print("是否连续-->", data2.is_contiguous())# 不连续的内存需要先contiguous()再viewdata3 = data2.contiguous().view(6)print("contiguous+view-->", data3)

四、张量拼接操作

4.1 cat()函数

def dm01_cat():# 生成两个三维张量data1 = torch.randint(0, 10, (1, 2, 3))data2 = torch.randint(0, 10, (1, 2, 3))# 拼接张量data3 = torch.cat([data1, data2], dim=0)  # 在第0维拼接print("dim=0拼接-->", data3.shape)  # [2, 2, 3]data3 = torch.cat([data1, data2], dim=1)  # 在第1维拼接print("dim=1拼接-->", data3.shape)  # [1, 4, 3]data3 = torch.cat([data1, data2], dim=2)  # 在第2维拼接print("dim=2拼接-->", data3.shape)  # [1, 2, 6]

拼接规则:

  • 除了拼接的维度,其他维度必须相同
  • catconcatconcatenate功能类似,掌握一个即可

五、自动微分机制详解

自动微分是PyTorch的核心特性,它能够自动计算梯度,这是深度学习训练的基础:

5.1 基本梯度计算

def dm01_calc_grad():# 创建需要计算梯度的张量x = torch.tensor([3, 4], requires_grad=True, dtype=torch.float)# 定义目标函数y = x ** 2print("y-->", y)# 计算梯度前print("初始梯度-->", x.grad)# 计算梯度y.sum().backward()  # 标量才能直接backwardprint("计算后梯度-->", x.grad)

5.2 多次循环计算梯度

def dm02_calc_grad():x = torch.tensor(3, requires_grad=True, dtype=torch.float)for i in range(50):# 定义目标函数y = x ** 2# 梯度清零(重要!)if x.grad is not None:x.grad.zero_()# 计算梯度y.backward()# 更新参数:w新 = w旧 - lr * gradx.data = x.data - 0.01 * x.gradprint(f"第{i+1}次: 梯度={x.grad}, 参数={x.data}")

5.3 计算图流逻辑

def dm03_calc_graph_flow():# 输入数据(不需要梯度)x = torch.tensor(5)y = torch.tensor(0.)# 模型参数(需要梯度)w = torch.tensor(3., requires_grad=True)b = torch.tensor(2., requires_grad=True)# 前向传播z = w * x + b# 损失函数criterion = nn.MSELoss()loss = criterion(z, y)print("损失-->", loss)# 反向传播loss.backward()print("w的梯度-->", w.grad)print("b的梯度-->", b.grad)

5.4 矩阵参数更新

def dm04_matrix():# 输入和输出x = torch.ones(2, 5)y = torch.zeros(2, 3)# 权重矩阵参数# x @ w + b = y# (2,5) @ (5,3) + (3) = (2,3)w = torch.randn((5, 3), requires_grad=True)b = torch.randn(3, requires_grad=True)# 前向传播z = torch.matmul(x, w) + b# 损失计算criterion = nn.MSELoss()loss = criterion(z, y)# 反向传播loss.backward()print("w的梯度-->", w.grad)print("b的梯度-->", b.grad)

六、线性回归实战案例

6.1 构建数据集

import matplotlib
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
from torch import nn, optimmatplotlib.use("TkAgg")def get_dataset():# 生成回归数据集x, y, coef = make_regression(n_samples=100,n_features=1,bias=2,noise=10,coef=True,random_state=22)# 转换为张量tensor_x = torch.tensor(x, dtype=torch.float)tensor_y = torch.tensor(y, dtype=torch.float)return tensor_x, tensor_y, coef

6.2 数据可视化

def show_data():x, y, coef = get_dataset()# 绘制散点图plt.scatter(x, y)# 绘制真实直线plt.plot(x.numpy(), coef * x.numpy() + 2)plt.show()

6.3 构建模型

def make_model():# 线性模型model = nn.Linear(in_features=1, out_features=1)# 损失函数criterion = nn.MSELoss()# 优化器optimizer = optim.SGD(params=model.parameters(), lr=0.01)return model, criterion, optimizer

七、重要概念总结

7.1 梯度计算要点

  1. requires_grad=True:只有设置此参数才能计算梯度
  2. 数据类型:计算梯度必须是浮点类型
  3. 梯度清零:每次迭代前需要清零历史梯度
  4. 标量backward:只有标量才能直接调用backward()

7.2 形状变换选择

  • reshape:改变张量形状,不改变数据
  • view:类似reshape,但要求内存连续
  • transpose:交换两个维度
  • permute:交换多个维度
  • squeeze/unsqueeze:增加或减少维度为1的维度

7.3 索引操作技巧

  • 简单索引data[i, j]
  • 范围索引data[start:end]
  • 列表索引data[[i, j], [k, l]]
  • 布尔索引data[condition]

八、实践建议

  1. 理解维度概念:掌握dim参数的含义和使用
  2. 掌握索引技巧:灵活使用各种索引方式
  3. 注意内存管理:理解contiguous()的作用
  4. 梯度计算规范:正确设置requires_grad和梯度清零
  5. 多动手实践:通过实际案例加深理解

结语

本文详细介绍了PyTorch Day02的核心内容,包括张量运算、索引操作、形状变换和自动微分机制。这些知识是构建神经网络模型的基础,掌握它们对于后续的深度学习学习至关重要。

在下一篇文章中,我们将学习更高级的神经网络构建和训练技巧。希望这篇文章对您的学习有所帮助!


关键词: PyTorch、张量运算、索引操作、形状变换、自动微分、梯度计算、线性回归

相关文章:

  • 深度学习基础Day01:PyTorch张量创建与操作详解

文章转载自:

http://Z1iV8bZO.ghsLr.cn
http://8vGZyVOe.ghsLr.cn
http://zIJ6Ubln.ghsLr.cn
http://wUm0bXkE.ghsLr.cn
http://AjFVjmXh.ghsLr.cn
http://QVjyBd1b.ghsLr.cn
http://8GuBRVyJ.ghsLr.cn
http://K1tXeRXg.ghsLr.cn
http://gjwss1Q4.ghsLr.cn
http://0KGJ2i3q.ghsLr.cn
http://jJb0DfbS.ghsLr.cn
http://lFnI9Wp8.ghsLr.cn
http://i9XnaLpV.ghsLr.cn
http://dcCdrt1M.ghsLr.cn
http://bpgbCGZa.ghsLr.cn
http://zrAFj3ez.ghsLr.cn
http://DkmCTyQX.ghsLr.cn
http://MTIy5oYi.ghsLr.cn
http://dmJM9PhA.ghsLr.cn
http://wRuqiA1Y.ghsLr.cn
http://MU5rfM27.ghsLr.cn
http://L1FossJc.ghsLr.cn
http://NcufJx6k.ghsLr.cn
http://WvanTv1e.ghsLr.cn
http://mVEIpM5Y.ghsLr.cn
http://7Oimf2tb.ghsLr.cn
http://ODyReoFK.ghsLr.cn
http://muy6Z7Vt.ghsLr.cn
http://Gyqu4irS.ghsLr.cn
http://k2EHWcBR.ghsLr.cn
http://www.dtcms.com/a/388593.html

相关文章:

  • Simulink变量优先级与管理策略
  • 大模型学习:什么是FastText工具
  • 从芯片到云:微软Azure全栈硬件安全体系构建可信基石
  • 当文件传输遇上网络波动:如何实现稳定高效的数据交换
  • C++访问限定符private、public、protected的使用场景
  • springboot 使用CompletableFuture多线程调用多个url接口,等待所有接口返回后统一处理接口返回结果
  • 科普:build与make
  • 对比OpenCV GPU与CPU图像缩放的性能与效果差异
  • 网络工程师行业新技术新概念
  • 【Linux】Linux中dos2unix 工具转换文件格式
  • 实验4:表单控件绑定(2学时)
  • QT OpenCV 准备工具
  • 无锁化编程(Lock-Free Programming)分析
  • Centons7 docker 安装 playwright
  • 远距离传输大型文件:企业数字化转型的挑战与突破
  • 氧气科技亮相GDMS全球数字营销峰会,分享AI搜索时代GEO新观
  • useMemo和useCallback
  • 【数据结构---图的原理与最小生成树算法,单源最短路径算法】
  • 有发声生物(猫狗鸟等)与无发声生物(蚂蚁蝙蝠蛇等)的 “感知-->行动“
  • CC 攻击为什么越来越难防?
  • 量化交易 - Multiple Regression 多变量线性回归(机器学习)
  • 【机器学习】基于双向LSTM的IMDb情感分析
  • CLR-GAN训练自己的数据集
  • LeetCode 242 有效的字母异位词
  • 中州养老:Websocket实现报警通知
  • python+excel实现办公自动化学习
  • 深度学习快速复现平台AutoDL
  • 《股票智能查询与投资决策辅助应用项目方案》
  • nvm安装包分享【持续更新】
  • 2025年- H143-Lc344. 反转字符串(字符串)--Java版