当前位置: 首页 > news >正文

实战训练1笔记

导入库
import torch
import torchvision
import torchvision.transforms as transforms
  • 导入PyTorch库及其视觉模块和转换模块。

数据预处理与加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  • 定义数据预处理步骤:转换为Tensor并归一化。

  • 加载CIFAR-10训练集和测试集,设置批大小为4,使用2个工作线程。

  • 定义类别标签。

图像显示函数
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inlinedef imshow(img):img = img / 2 + 0.5  # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()dataiter = iter(trainloader)
images, labels = next(dataiter)imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
  • 导入Matplotlib库用于图像显示。

  • 定义imshow函数用于显示图像。

  • 从训练加载器中获取一批图像和标签。

  • 显示图像并打印标签。

构建卷积神经网络
import torch.nn as nn
import torch.nn.functional as Fclass CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(16, 36, kernel_size=3, stride=1)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(1296, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 36*6*6)x = F.relu(self.fc2(F.relu(self.fc1(x))))return xnet = CNNNet()
net = net.to(device)
  • 定义卷积神经网络结构,包括两个卷积层、两个池化层和两个全连接层。

  • 实例化网络并将其移动到指定设备(GPU或CPU)。

定义损失函数和优化器
import torch.optim as optimLR = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#optimizer = optim.Adam(net.parameters(), lr=LR)
  • 定义交叉熵损失函数。

  • 定义随机梯度下降优化器,学习率为0.001

打印网络结构
print(net)
  • 打印网络结构。

提取模型中的前四层
nn.Sequential(*list(net.children())[:4])
  • 提取并打印网络的前四层。

训练模型
for epoch in range(10):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')
  • 进行10个epoch的训练。

  • 在每个epoch中,遍历训练数据加载器。

  • 将输入和标签移动到设备。

  • 清零梯度,进行前向传播、计算损失、反向传播和优化器更新。

  • 每2000个小批量打印一次损失。

显示测试集图像
dataiter = iter(testloader)
for i, (images, labels) in enumerate(dataiter):imshow(torchvision.utils.make_grid(images))print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
  • 从测试加载器中获取一批图像和标签。

  • 显示图像并打印真实标签。

预测测试集结果
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
  • 将图像和标签移动到设备。

  • 进行前向传播,获取输出。

  • 获取预测结果并打印。

总结

  • 导入必要的库。

  • 定义数据预处理步骤并加载CIFAR-10数据集。

  • 定义图像显示函数并显示一批图像及其标签。

  • 构建卷积神经网络结构。

  • 定义损失函数和优化器。

  • 打印网络结构并提取前四层。

  • 进行模型训练并每2000个小批量打印一次损失。

  • 显示测试集图像及其真实标签。

  • 预测测试集结果并打印。

http://www.dtcms.com/a/411399.html

相关文章:

  • 网站制作程序下载ngo网页模板下载
  • C++学习记录(13)二叉排序树
  • TongWeb下如何获取数据源的物理连接?
  • 保险资料网站有哪些三网合一网站建设报价
  • 网站建设系统分析ai的优点和缺点
  • 三网合一网站百度一下免费下载
  • 坤驰科技携数据采集解决方案,亮相中国光纤传感大会
  • 可以做免费的网站吗广州平面设计工作室
  • 【文献阅读】基于机器学习的网络最差鲁棒性可扩展快速评估框架
  • 【复习】计网每日一题--PPP协议透明传输
  • 【训练技巧】torch.amp.GradScaler 里面当scale系数为0或者非常小的时候,详细分析与解决思路
  • 一站式服务logo设计深圳网站建设服务商哪些好?
  • 专业的网站建设公司电话做商城网站要什么手续
  • mdBook 开源笔记
  • 【1、Kotlin 基础语法】2、Kotlin 变量
  • TorchV知识库安全解决方案:基于智能环境感知的动态权限控制
  • 网站后台演示2023小规模企业所得税税率是多少
  • 常见设计模式讲解
  • 怎么查网站备案服务商房地产新闻动态
  • php做网站主题建设项目一次公示网站
  • 同城外卖系统技术解析:SpringBoot如何赋能区域外卖突围战
  • .NET Framework 4.0.30319:官方下载与常见问题解决指南
  • 池州网站优化有没有网站做字体变形
  • 【论文阅读 | ICCV 2025 | M-SpecGene:面向 RGBT 多光谱视觉的通用基础模型​​】
  • 江苏省省建设厅网站公司的介绍怎么写
  • 专门做二手手机的网站吗网站建设 协议书 doc
  • Kubernetes Headless Service 深度解析 —— 用大白话讲清楚
  • 做网站的软件pageseo策略
  • 怀化冰山涯IT网站建设公司电子商务网站开发背景和意义
  • 免费设立网站企业对比网站