当前位置: 首页 > news >正文

Pytorch常用API(ML和DL)

在 PyTorch 中,API 设计围绕 “张量操作”“自动求导”“神经网络构建”“训练工具” 等核心模块展开,以下是 ML/DL 中最常用的 API 分类及关键功能介绍,附带典型用法场景:

一、核心:张量(Tensor)操作(torch

张量是 PyTorch 的基础数据结构(类似 NumPy 数组,但支持 GPU 加速和自动求导),所有计算围绕张量展开。

1. 张量创建
  • torch.tensor(data, dtype=None, device=None):从数据(列表 / NumPy 数组)创建张量,指定数据类型(如torch.float32)和设备(CPU/GPU)。例:x = torch.tensor([[1,2],[3,4]], dtype=torch.float32)

  • 快捷创建:

    • torch.zeros(size):全 0 张量(如torch.zeros(2,3)
    • torch.ones(size):全 1 张量
    • torch.randn(size):标准正态分布(均值 0,方差 1)
    • torch.arange(start, end, step):等差数列(如torch.arange(0,10,2)
2. 张量属性与设备
  • tensor.shape / tensor.size():张量维度(如(2,3)表示 2 行 3 列)
  • tensor.dtype:数据类型(如torch.float32torch.int64
  • tensor.device:存储设备(cpucuda:0
  • 设备切换:tensor.to('cuda') 或 tensor.cuda()(移到 GPU);tensor.cpu()(移回 CPU)
3. 常用操作
  • 重塑tensor.reshape(new_shape)(灵活调整维度)、tensor.view(new_shape)(要求连续内存)例:x = torch.randn(4,3); x.reshape(2,6)

  • 拼接 / 堆叠

    • torch.cat([t1, t2], dim=0):沿指定维度拼接(维度不变,长度增加)
    • torch.stack([t1, t2], dim=0):新增维度堆叠(维度 + 1)
  • 数学运算

    • 逐元素:torch.add(a,b)(或a+b)、torch.mul(a,b)(或a*b
    • 矩阵乘法:torch.matmul(a,b)(或a @ b,支持高维矩阵)
    • 聚合:torch.sum(tensor, dim=0)(按维度求和)、torch.mean()torch.max()

二、自动求导(torch.autograd

深度学习的核心机制,自动计算张量的梯度(用于反向传播更新参数)。

  • 启用梯度跟踪:创建张量时指定requires_grad=True,或用tensor.requires_grad_()启用。例:x = torch.tensor([2.0], requires_grad=True)

  • 前向传播与梯度计算

    • 定义计算图(如y = x**2 + 3*x
    • 反向传播:y.backward()(自动计算y对所有requires_grad=True的张量的梯度)
    • 获取梯度:x.grad(存储dy/dx的结果)
  • 停止梯度跟踪

    • tensor.detach():返回不跟踪梯度的张量副本
    • with torch.no_grad()::上下文管理器,内部操作不记录梯度(推理时用,节省内存)

三、神经网络模块(torch.nn

用于快速构建深度学习模型,包含预定义层、激活函数、损失函数等。

1. 模型基类 nn.Module

所有自定义模型必须继承此类,通过forward()定义前向传播逻辑。

例:定义一个简单的线性回归模型

import torch.nn as nnclass LinearModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(in_features=1, out_features=1)  # 线性层:y = wx + bdef forward(self, x):return self.linear(x)  # 前向传播:输入x -> 线性层输出
2. 常用网络层
  • 全连接层nn.Linear(in_features, out_features)(用于 MLP)
  • 卷积层nn.Conv2d(in_channels, out_channels, kernel_size)(2D 卷积,用于 CNN)
  • 池化层nn.MaxPool2d(kernel_size)(最大池化,降维)
  • 循环层nn.LSTM(input_size, hidden_size, num_layers)(长短期记忆网络,用于序列数据)
  • 激活函数nn.ReLU()nn.Sigmoid()nn.Softmax(dim=1)(引入非线性)
  • 正则化层nn.Dropout(p=0.5)(随机失活,防止过拟合)、nn.BatchNorm2d(num_features)(批归一化)
3. 损失函数(nn.loss

定义模型预测与真实标签的差距,用于反向传播更新参数。

  • 回归任务:nn.MSELoss()(均方误差)、nn.L1Loss()(平均绝对误差)
  • 分类任务:
    • nn.CrossEntropyLoss()(交叉熵,适用于多分类,内置 Softmax)
    • nn.BCELoss()(二元交叉熵,输入需经 Sigmoid)
    • nn.BCEWithLogitsLoss()(结合 Sigmoid 和 BCELoss,更稳定)

四、优化器(torch.optim

用于根据梯度更新模型参数,最小化损失函数。

  • 常用优化器:

    • optim.SGD(params, lr=0.01, momentum=0.9):随机梯度下降(带动量加速收敛)
    • optim.Adam(params, lr=0.001):自适应学习率,适合多数场景(默认首选)
    • optim.RMSprop()optim.Adagrad()
  • 用法流程:

model = LinearModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 传入模型参数和学习率# 训练时:
optimizer.zero_grad()  # 清空上一轮梯度
loss = criterion(output, target)  # 计算损失
loss.backward()  # 反向传播求梯度
optimizer.step()  # 更新参数

五、数据加载(torch.utils.data

用于高效加载和预处理数据,支持批处理、多线程加载。

  • Dataset:抽象类,需自定义__len__()(数据总数)和__getitem__()(按索引取数据)。例:自定义数据集

from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, x, y):self.x = x  # 特征self.y = y  # 标签def __len__(self):return len(self.x)def __getitem__(self, idx):return self.x[idx], self.y[idx]
  • DataLoader:包装Dataset,支持批处理、打乱、多线程:
from torch.utils.data import DataLoaderdataset = MyDataset(x_data, y_data)
dataloader = DataLoader(dataset,batch_size=32,  # 每批32个样本shuffle=True,   # 训练时打乱数据num_workers=4   # 4个线程加载数据
)

六、模型保存与加载

  • 保存:torch.save(model.state_dict(), "model.pth")(推荐只保存参数,轻量)
  • 加载:
model = LinearModel()
model.load_state_dict(torch.load("model.pth"))  # 加载参数
model.eval()  # 切换到评估模式(关闭Dropout/BatchNorm等)

七、其他实用工具

  • torchvision.transforms:图像预处理工具(如ToTensor()转张量、Resize()调整大小、Normalize()标准化)。
  • torch.no_grad():推理时关闭梯度计算(节省内存)。
  • nn.Sequential:按顺序组合层,简化模型定义(如nn.Sequential(nn.Linear(2,3), nn.ReLU()))。

总结

PyTorch 的 API 设计注重灵活性,核心逻辑可概括为:数据(张量)→ 模型(nn.Module)→ 损失(nn.Loss)→ 优化器(optim)→ 训练循环(前向传播→反向传播→参数更新)

http://www.dtcms.com/a/528865.html

相关文章:

  • 切水题2.0
  • 深入解析C++ String类的实现奥秘
  • 机器视觉的液晶电视OCA全贴合应用
  • 个人博客网站页面儿童玩具网站建设策划书
  • 构建大模型安全自动化测试框架:从手工POC到AI对抗AI的递归Fuzz实践
  • 数据库约束与查询:MySQL 中的 DQL 和约束全解析
  • C++笔记(面向对象)友元
  • 网站在工信部备案查询oa系统开发
  • FPGA基础知识(七):引脚约束深度解析--从物理连接到时序收敛的完整指南
  • Minecraft-Speed-Proxy——搭建专属的Minecraft加速IP
  • Flutter 异步 + 状态管理融合实践:Riverpod 与 Bloc 双方案解析
  • 10.25复习LRU缓存[特殊字符]
  • 做网站怎么那么难谷歌关键词排名查询工具
  • 门户网站的建设与维护注册域名多长时间
  • 实战:将 Nginx 日志实时解析并写入 MySQL,不再依赖 ELK
  • Redis 黑马点评day02 商户查询缓存
  • 品牌网站建设切入点wordpress很好的博客
  • ASP.NET Core读取Excel文件
  • 器材管理网站开发沈阳网站建设费用
  • 巧用 CSS linear-gradient 实现多种下划线文字特效(纯 CSS 无需额外标签)
  • 地州电视网站建设流程网址域名大全
  • 计算机网络自顶向下方法 1——因特网的介绍及构成 介绍协议
  • 学习笔记|受限波尔兹曼机(RBM)
  • DiVE长尾识别的虚拟实例蒸馏方法
  • 视频网站很难建设吗珠海网站运营
  • h5游戏免费下载:废柴勇士
  • 简单的企业网站源码网站建设业务
  • 基于鸿蒙 UniProton 的汽车电子系统开发指南
  • 建设部质监局网站电子商务网站策划书2000字
  • 使用表达式树实现字符串形式的表达式访问对象属性