当前位置: 首页 > news >正文

提升准确率的处理

# 第一部分:导入模块、定义超参数和模型结构
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from collections import CounterBATCHSIZE = 100
DOWNLOAD_MNIST = False
EPOCHES = 20
LR = 0.001class CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(1296, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 36 * 6 * 6)x = F.relu(self.fc2(F.relu(self.fc1(x))))return xclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 36, 5)self.pool2 = nn.MaxPool2d(2, 2)self.aap = nn.AdaptiveAvgPool2d(1)self.fc3 = nn.Linear(36, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.aap(x)x = x.view(x.shape[0], -1)x = self.fc3(x)return xclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out# 第二部分:数据准备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print('==> Preparing data..')
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 第三部分:模型、损失函数和优化器定义
print('==> Building model..')
net1 = CNNNet()
net2 = Net()
net3 = LeNet()import torch.optim as optim
LR = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)# 第四部分:模型训练
for epoch in range(10):running_loss = 0.0for i, data in enumerate(trainloader, 0):# 获取训练数据inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 权重参数梯度清零optimizer.zero_grad()# 正向及反向传播outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 显示损失值running_loss += loss.item()if i % 2000 == 1999:  # print every 2000 mini-batchesprint('[Epoch: %d, Batch: %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0
print('Finished Training')# 第五部分:显示各层参数
import collections
def params_summary(input_size, model):def register_hook(module):def hook(module, input, output):class_name = str(module.__class__).split('.')[-1].split("'")[0]module_idx = len(summary)m_key = '%s-%i' % (class_name, module_idx + 1)summary[m_key] = collections.OrderedDict()summary[m_key]['input_shape'] = list(input[0].size())summary[m_key]['input_shape'][0] = -1summary[m_key]['output_shape'] = list(output.size())summary[m_key]['output_shape'][0] = -1params = 0if hasattr(module, 'weight'):params += torch.prod(torch.LongTensor(list(module.weight.size())))if module.weight.requires_grad:summary[m_key]['trainable'] = Trueelse:summary[m_key]['trainable'] = Falseif hasattr(module, 'bias'):params += torch.prod(torch.LongTensor(list(module.bias.size())))summary[m_key]['nb_params'] = paramsif not isinstance(module, nn.Sequential) and \not isinstance(module, nn.ModuleList) and \not (module == model):hooks.append(module.register_forward_hook(hook))# check if there are multiple inputs to the networkif isinstance(input_size[0], (list, tuple)):x = [torch.rand(1, *in_size) for in_size in input_size]else:x = torch.rand(1, *input_size)# create propertiessummary = collections.OrderedDict()hooks = []model.apply(register_hook)# make a forward passmodel(x)# remove these hooksfor h in hooks:h.remove()return summary

数据处理与加载

  • 数据增强与预处理
    • transforms.RandomCrop(32, padding=4):对图像进行随机裁剪,裁剪后尺寸为 32×32,且在裁剪前在图像四周填充 4 个像素,这样可以增加数据的多样性,减少过拟合。
    • transforms.RandomHorizontalFlip():以一定概率(默认 0.5)对图像进行水平翻转,进一步扩充训练数据。
    • transforms.ToTensor():将 PIL 图像或 numpy 数组转换为 PyTorch 张量,并且将图像像素值从 [0,255] 归一化到 [0,1]。
    • transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)):对张量进行标准化,使用 CIFAR10 数据集的均值和标准差,使得数据分布更符合正态分布,有利于模型训练。
  • 数据集与数据加载器
    • torchvision.datasets.CIFAR10:加载 CIFAR10 数据集,root指定数据集存储路径,train参数区分训练集和测试集,download指定是否下载数据集(这里设为False表示使用本地已下载的数据集),transform指定对数据的预处理操作。
    • torch.utils.data.DataLoader:将数据集包装成可迭代的 DataLoader,batch_size指定每个批次的样本数量,shuffle指定是否在每个 epoch 前打乱数据,num_workers指定用于数据加载的子进程数量,加快数据加载速度。

2. 模型构建

  • nn.Module基类:所有的 PyTorch 模型都要继承nn.Module类,在__init__方法中定义模型的层结构,在forward方法中定义前向传播逻辑。
  • 卷积层(nn.Conv2d:用于提取图像的局部特征,参数包括in_channels(输入通道数)、out_channels(输出通道数)、kernel_size(卷积核大小)、stride(步长)等。例如nn.Conv2d(3, 16, 5)表示输入通道数为 3,输出通道数为 16,卷积核大小为 5×5,步长为 1。
  • 池化层
    • nn.MaxPool2d:最大池化层,对输入的特征图进行下采样,保留每个池化窗口中的最大值,减少参数数量和计算量,同时保持特征的主要信息。参数kernel_size为池化窗口大小,stride为池化步长。
    • nn.AdaptiveAvgPool2d:自适应平均池化层,将输入特征图调整到指定的输出尺寸,这里nn.AdaptiveAvgPool2d(1)是将特征图调整为 1×1 的大小,方便后续与全连接层连接。
  • 全连接层(nn.Linear:将卷积和池化得到的特征进行线性变换,用于分类等任务。参数in_features为输入特征数,out_features为输出特征数,例如nn.Linear(1296, 128)表示输入 1296 个特征,输出 128 个特征。
  • 激活函数(F.relu:ReLU(Rectified Linear Unit)激活函数,公式为\(f(x)=\max(0,x)\),可以引入非线性,增加模型的表达能力,解决线性模型无法拟合复杂数据的问题。

3. 模型训练

  • 优化器(optim.SGD:随机梯度下降优化器,用于更新模型参数以最小化损失函数。参数lr为学习率,控制参数更新的步长;momentum为动量,有助于加速梯度下降过程,特别是在面对局部极小值或鞍点时,能使优化过程更稳定。
  • 损失函数(nn.CrossEntropyLoss:用于多分类任务的损失函数,它结合了nn.LogSoftmaxnn.NLLLoss(负对数似然损失),计算预测概率分布与真实标签之间的交叉熵损失,公式为\(L=-\sum_{i}y_{i}\log(\hat{y}_{i})\),其中\(y_{i}\)是真实标签的 one - hot 编码,\(\hat{y}_{i}\)是模型预测的概率。
  • 训练循环
    • 外层循环(epoch循环):控制训练的轮数,每一轮会遍历整个训练数据集。
    • 内层循环(batch循环):遍历 DataLoader 中的每个批次数据。
    • optimizer.zero_grad():在每次计算梯度前,将优化器中保存的梯度清零,因为 PyTorch 会累积梯度,如果不清零会导致梯度计算错误。
    • outputs = net(inputs):前向传播,将输入数据传入模型,得到预测输出。
    • loss = criterion(outputs, labels):计算预测输出与真实标签之间的损失。
    • loss.backward():反向传播,计算损失关于模型参数的梯度。
    • optimizer.step():根据计算得到的梯度,更新模型的参数。
    • 损失监控:通过running_loss累积每个批次的损失,当达到指定的批次间隔(这里是 2000 个 mini - batches)时,打印损失值,用于观察模型训练过程中损失的变化情况,判断模型是否在收敛。

4. 模型参数查看

  • 前向传播钩子(register_forward_hook:通过给模型的每个模块注册前向传播钩子函数,在模型前向传播过程中,可以获取每个模块的输入、输出以及参数等信息。
  • params_summary函数
    • 首先处理输入数据的形状,生成随机的输入张量用于前向传播测试。
    • 然后通过model.apply(register_hook)遍历模型的每个模块,为每个模块注册钩子函数。
    • 在钩子函数中,记录模块的类名、输入输出形状、可训练性(通过module.weight.requires_grad判断)以及参数数量(通过计算module.weightmodule.bias的元素数量得到)。
    • 最后执行前向传播model(x),触发钩子函数记录信息,并在结束后移除钩子,避免对后续操作产生影响。通过这个函数可以详细了解模型各层的结构和参数情况,有助于模型的调试和分析。

5. 设备利用

  • torch.device('cuda' if torch.cuda.is_available() else 'cpu'):检测当前环境是否有可用的 CUDA 设备(即 GPU),如果有则使用 GPU 进行计算,否则使用 CPU。在训练过程中,通过inputs.to(device)labels.to(device)将数据移动到指定设备,通过net.to(device)将模型移动到指定设备,利用 GPU 的并行计算能力可以显著加快模型训练的速度。

http://www.dtcms.com/a/419678.html

相关文章:

  • 透明水印logo在线制作东莞市seo网络推广报价
  • App 上架服务全流程解析,iOS 应用代上架、ipa 文件上传工具、TestFlight 测试与苹果审核实战经验
  • 织梦网站版权银行营销活动方案
  • 自己做视频网站会不会追究版权做网站界面一般用什么来做
  • less和sass
  • 单片机开发---RP2040数据手册之PIO功能
  • 怎么免费做网站视频教学网站不收录 域名问题
  • 青海省城乡建设厅网站首页网站缩放代码
  • 学习2025.9.28
  • C++协程
  • 模电基础:多级放大电路与集成运放的认识
  • 汕头网站推广教程.电子商务网站规划
  • 深入理解哈希表:闭散列与开散列两种实现方案解析
  • 无锡网站推广公司排名线下推广都有什么方式
  • Linux从入门到精通——基础指令篇(耐人寻味)
  • 网站建设 运维 管理包括哪些公众号开发者中心在哪
  • IDEA AI Agent
  • 有没有帮人做数学题的网站现在网站建设都用什么语言
  • 解决Ubuntu22.04 安装telnetd ubuntu入门之二十九
  • 个人网站怎么写网站哪里可以做
  • 嵌入式linux内核驱动学习2——linux启动流程
  • 机械网站案例分析wordpress导航栏插件
  • 大姚县建设工程招标网站云平台网站叫什么
  • mysql独立表空间迁移
  • 泸州网站建设价格高端网站建设公司排名
  • 实战:SQL统一访问200+数据源,构建企业级智能检索与RAG系统(下)
  • 免费公司主页网站开源seo软件
  • 创建网站需要学什么知识四川省建设监理协会网站
  • Android Studio历史版本下载
  • Vue3 + TypeScript + Ant Design Vue 实战:密码表单校验与拓展功能(强度提示 + 显示/隐藏密码)