当前位置: 首页 > news >正文

PyTorch的基础概念和复杂模型的基本使用

文章目录

    • 一、PyTorch基础概念
    • 二、复杂模型的学习使用

一、PyTorch基础概念

  1. 张量(Tensor)操作
    • 张量是PyTorch中的基本数据结构,类似于NumPy的数组,但支持GPU加速
    • 常见操作包括创建张量、张量运算、索引、切片等
import torch# 创建张量
x = torch.randn(3, 4)
y = torch.zeros(3, 4)# 张量运算
z = x + y
  1. 自动求导(Autograd)
    • PyTorch的自动求导系统可以自动计算梯度
    • 通过requires_grad=True启用梯度计算
# 启用自动求导
x = torch.randn(3, 4, requires_grad=True)# 计算损失
y = x * 2
loss = y.sum()# 反向传播
loss.backward()
  1. 计算图
    • PyTorch使用动态计算图(Define-by-Run)的方式
    • 每次前向传播都会构建一个新的计算图

二、复杂模型的学习使用

  1. 神经网络模块(torch.nn)
    • torch.nn提供了构建神经网络所需的各种组件
    • 主要包括各种层(如线性层、卷积层)、激活函数、损失函数等
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x
  1. 卷积神经网络(CNN)
    • 适用于图像处理任务
    • 包含卷积层、池化层等
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.fc1 = nn.Linear(12*12*64, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(-1, 12*12*64)x = F.relu(self.fc1(x))x = self.fc2(x)return x
  1. 循环神经网络(RNN)
    • 适用于序列数据处理任务
    • 包括RNN、LSTM、GRU等变体
class RNNModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNNModel, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])return out
  1. 训练流程
    • 数据加载:使用DataLoaderDataset加载数据
    • 模型定义:定义神经网络结构
    • 损失函数:选择合适的损失函数(如交叉熵损失)
    • 优化器:选择优化器(如Adam)并传入模型参数
    • 训练循环:执行前向传播、计算损失、反向传播和参数更新
from torch.utils.data import DataLoader, TensorDataset# 创建数据集
x_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)# 创建模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练循环
for epoch in range(10):for inputs, targets in dataloader:outputs = model(inputs)loss = criterion(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()
  1. 模型保存与加载
    • 使用torch.save()torch.load()保存和加载模型
# 保存模型
torch.save(model.state_dict(), "model.pth")# 加载模型
model = Net()
model.load_state_dict(torch.load("model.pth"))
http://www.dtcms.com/a/289311.html

相关文章:

  • 【软件测试】从软件测试到Bug评审:生命周期与管理技巧
  • ESXi6.7硬件传感器红色警示信息
  • ICT模拟零件测试方法--测量参数详解
  • ThinkPHP8极简上手指南:开启高效开发之旅
  • 基于机器视觉的迈克耳孙干涉环自动计数系统设计与实现
  • STM32CubeMX的一些操作步骤的作用
  • 拼写纠错模型Noisy Channel(下)
  • 机器学习理论基础 - 核心概念篇
  • 复杂度优先:基于推理链复杂性的提示工程新范式
  • Linux操作系统之线程(四):线程控制
  • 20250720-1-Kubernetes 调度-白话理解创建一个Pod的内部工作流_笔记
  • Qt的安装和环境配置
  • Ubuntu挂载和取消挂载
  • 【vue-7】Vue3 响应式数据声明:深入理解 reactive()
  • Matlab自学笔记六十四:求解自变量带有约束条件的方程
  • 相同问题的有奇点模型和无奇点模型有什么区别
  • 服务器上的文件复制到本地 Windows 系统
  • [学习] 深入理解傅里叶变换:从时域到频域的桥梁
  • 04训练windows电脑低算力显卡如何部署pytorch实现GPU加速
  • LINUX720 SWAP扩容;新增逻辑卷;逻辑卷扩容;数据库迁移;gdisk
  • 【超越VGGT】π3-利用置换等变方法去除3r系列的归纳偏置
  • 机器视觉---深度图像存储格式
  • 监督学习应用
  • 零基础学习性能测试第三章:执行性能测试
  • Spring Boot 订单超时自动取消的 3 种主流实现方案
  • 将SAC强化学习算法部署到ROS2的完整指南
  • 基于卷积傅里叶分析网络 (CFAN)的心电图分类的统一时频方法
  • 复杂度+包装类型+泛型
  • 全平台爬虫配置流程
  • Spark专栏开篇:它从何而来,为何而生,凭何而强?