当前位置: 首页 > news >正文

通过构建神经网络实现项目预测

一、PyTorch

  1. 数据集与数据加载

    • MNIST:PyTorch 内置的手写数字数据集类,train=True加载训练集,train=False加载测试集。
    • transforms.Compose:组合多个数据预处理操作。例:transforms.ToTensor()(将图像转为张量)、transforms.Normalize([0.5], [0.5])(标准化,均值 = 0.5,标准差 = 0.5)。
    • DataLoader:创建数据加载器,实现批量加载(batch_size)、打乱数据(shuffle=True)、多进程加载等功能。
  2. 神经网络构建

    • nn.Module:所有神经网络模型的基类,自定义模型需继承此类并实现__init__(定义层)和forward(前向传播)方法。
    • 常用层:
      • nn.Flatten():将多维张量展平为一维(例如 28×28 图像→784 维向量)。
      • nn.Linear(in_features, out_features):全连接层,实现y = x·W + b运算。
      • nn.BatchNorm1d(num_features):批归一化层,加速训练并稳定梯度(对每个批次数据标准化)。
    • 激活函数:
      • F.relu(x):ReLU 激活函数,引入非线性,relu(x) = max(0, x)
      • F.softmax(x, dim=1):将输出转为概率分布(沿 dim=1 维度求和为 1),用于多分类任务。
  3. 设备配置

    • torch.device("cuda:0" if torch.cuda.is_available() else "cpu"):自动选择运行设备(优先 GPU,若无则用 CPU)。
    • model.to(device):将模型参数迁移到指定设备;img.to(device):将数据迁移到同一设备(模型和数据需在同设备上运算)。
  4. 损失函数与优化器

    • nn.CrossEntropyLoss():交叉熵损失,常用于分类任务(结合了nn.LogSoftmaxnn.NLLLoss)。
    • optim.SGD(model.parameters(), lr=0.01, momentum=0.9):随机梯度下降优化器,lr为学习率,momentum为动量(加速收敛,减少震荡)。
    • 优化器操作:
      • optimizer.zero_grad():清空上一轮梯度(避免累积)。
      • loss.backward():反向传播计算梯度。
      • optimizer.step():根据梯度更新模型参数。
  5. 模型训练与评估模式

    • model.train():切换为训练模式(启用 dropout、批归一化的训练行为)。
    • model.eval():切换为评估模式(禁用 dropout,批归一化使用移动均值 / 方差)。

二、训练

  1. 动态调整学习率

    • optimizer.param_groups[0]['lr'] *= 0.9:每 5 个 epoch 将学习率乘以 0.9(衰减学习率,避免后期震荡)。
  2. 准确率计算

    • out.max(1):返回沿第 1 维度的最大值和索引(索引对应预测的类别)。
    • (pred == label).sum().item():统计预测正确的样本数(pred为预测类别,label为真实类别)。
    • 准确率 = 正确样本数 / 总样本数(num_correct / img.shape[0])。
  3. 训练日志与可视化

    • SummaryWriter:TensorBoard 日志记录工具,writer.add_scalar()记录损失等指标随 epoch 的变化。
    • matplotlib.pyplot:绘制训练损失曲线和数据样本可视化(如手写数字图像)。

可执行代码PyThon版

import numpy as np
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt# 定义一些超参数
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epochs = 20# 定义预处理函数
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# 下载数据,并对数据进行预处理
train_dataset = MNIST('../data/', train=True, transform=transform, download=True)
test_dataset = MNIST('../data/', train=False, transform=transform)
# 得到一个生成器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)# 可视化源数据
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)fig = plt.figure()
for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(example_data[i][0], cmap='gray', interpolation='none')plt.title("Ground Truth: {}".format(example_targets[i]))plt.xticks([])plt.yticks([])
plt.show()# 构建模型
class Net(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Net, self).__init__()self.flatten = nn.Flatten()self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = self.flatten(x)x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))x = F.softmax(self.out(x), dim=1)return x# 实例化模型、定义损失函数和优化器
lr = 0.01
momentum = 0.9
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net(28 * 28, 300, 100, 10)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)# 开始训练
losses = []
acces = []
eval_losses = []
eval_acces = []
writer = SummaryWriter(log_dir='logs', comment='train-loss')for epoch in range(num_epochs):train_loss = 0train_acc = 0model.train()# 动态修改参数学习率if epoch % 5 == 0:optimizer.param_groups[0]['lr'] *= 0.9print('学习率:{:.6f}'.format(optimizer.param_groups[0]['lr']))for img, label in train_loader:img = img.to(device)label = label.to(device)# 正向传播out = model(img)loss = criterion(out, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.item()# 保存loss的数据与epoch数值writer.add_scalar('Train', train_loss / len(train_loader), epoch)# 计算分类的准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]train_acc += acclosses.append(train_loss / len(train_loader))acces.append(train_acc / len(train_loader))# 在测试集上检验效果eval_loss = 0eval_acc = 0model.eval()for img, label in test_loader:img = img.to(device)label = label.to(device)img = img.view(img.size(0), -1)out = model(img)loss = criterion(out, label)# 记录误差eval_loss += loss.item()# 记录准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]eval_acc += acceval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader),eval_loss / len(test_loader), eval_acc / len(test_loader)))# 绘制训练损失曲线
plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')
plt.show()

http://www.dtcms.com/a/477941.html

相关文章:

  • 沈阳学网站制作学校百度应用搜索
  • 从零搭建鸿蒙高效数据存储框架:RdbStore全流程实战与性能优化
  • 图像处理-opencv(二)-形态学
  • 数字资产反诈指南:识别套路,守护WEEX账户安全
  • 深入剖析:Playwright MCP Server 的工作机制与性能优化策略
  • 下载好了网站模板怎么开始做网站一家专门做男人的网站
  • 记一次顽固eazyExcel异常排查
  • 网站的站点的管理系统手机网站 微信网站 区别
  • CentOS 7的内网环境中将OpenSSH升级到较高版本
  • 用你本地已有的私钥(private key)去 SSH 登录远程 Ubuntu 服务器
  • Ruby小白学习路线
  • 做网站项目需求分析是什么网站制作的评价指标
  • 普陀营销型网站建设微信登录界面
  • 一文入门Rust语言
  • FFmpeg开发笔记(十三):ffmpeg采集麦克风音频pcm重采样为aac录音为AAC文件
  • 深度学习实战:基于 PyTorch 的 MNIST 手写数字识别
  • 字符串逆序的优雅实现:双指针法的巧妙应用
  • [GO]golang接口入门:从一个简单示例看懂接口的多态与实现
  • 文章管理系统CMS的XSS注入渗透测试(白盒)
  • 主机做网站服务器吗成都网站建设服务功能
  • 北京网站关键词优化南昌网站建设哪家比较好
  • 前端Vue 后端ASP.NET Core WebApi 本地调试交互过程
  • KeepMouseSpeedOK:专业鼠标速度调节工具
  • leetcode 169. 多数元素
  • 沟通交流类网站有哪些ui外包网站
  • LeetCode——双指针(进阶)
  • SQL Server 2019实验 │ 安装及其管理工具的使用
  • RAGE框架:确保AI Prompt高效率高质量输出
  • aspcms 你的网站未安装 请先安装qq wordpress登陆
  • 广州白云做网站的公司百度推广有哪些形式