当前位置: 首页 > news >正文

深度学习——加载数据

 CIFAR-10图像分类项目

1. 深度学习框架与库导入

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

​知识点详解:​

  1. ​PyTorch​​:Facebook开发的开源深度学习框架,提供张量计算和自动求导功能

  2. ​torchvision​​:PyTorch的计算机视觉库,包含常用数据集、模型架构和图像转换工具

  3. ​transforms​​:提供图像预处理和数据增强功能

  4. ​nn.Module​​:所有神经网络模块的基类,用于构建自定义网络

  5. ​torch.nn.functional​​:包含各种神经网络函数(如激活函数、损失函数)

  6. ​torch.optim​​:提供各种优化算法(如SGD、Adam)

2. 多进程处理与主模块保护

if __name__ == '__main__': # 代码主体

​知识点详解:​

  1. if __name__ == '__main__'​:Python的特殊语法,确保代码只在主模块中执行

  2. ​多进程问题​​:在Windows系统中,PyTorch的多进程数据加载需要此保护,避免递归创建子进程

  3. ​num_workers=0​​:设置数据加载器为单进程模式,避免多进程错误

3. 文件路径处理与验证

import os
root_dir = r"C:\Users\j1378\python pycharm\pythonProject\Clockin\作业"
dataset_folder = os.path.join(root_dir, "cifar-10-batches-py")
if not os.path.exists(dataset_folder):print(f"错误:数据集文件夹不存在 - {dataset_folder}")exit(1)

​知识点详解:​

  1. ​os.path​​:Python标准库中的路径处理模块

  2. ​os.path.join()​​:跨平台安全的路径拼接方法

  3. ​os.path.exists()​​:检查文件或目录是否存在

  4. ​错误处理​​:验证数据集路径,提供清晰的错误信息并优雅退出

4. 数据预处理与标准化

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

​知识点详解:​

  1. ​transforms.Compose()​​:将多个图像变换组合在一起

  2. ​ToTensor()​​:将PIL图像或numpy数组转换为PyTorch张量,并自动缩放像素值到[0,1]范围

  3. ​Normalize()​​:标准化处理,使用均值(0.5,0.5,0.5)和标准差(0.5,0.5,0.5)对每个通道进行标准化

    1. 公式:output = (input - mean) / std

    2. 效果:将像素值从[0,1]范围转换到[-1,1]范围

5. 数据集加载与管理

trainset = torchvision.datasets.CIFAR10(root=root_dir,train=True,download=False,transform=transform
)
testset = torchvision.datasets.CIFAR10(root=root_dir,train=False,download=False,transform=transform
)

​知识点详解:​

  1. ​CIFAR-10数据集​​:包含10个类别的6万张32x32彩色图像(5万训练+1万测试)

  2. ​torchvision.datasets​​:提供常用数据集的便捷访问接口

  3. ​数据划分​​:train=True获取训练集,train=False获取测试集

  4. ​transform参数​​:指定应用于数据的预处理转换

6. 数据加载器(DataLoader)

trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=0
)

​知识点详解:​

  • ​DataLoader​​:提供数据集的小批量迭代访问

  • ​批处理(batch_size)​​:每次迭代返回指定数量的样本(此处为4)

  • ​数据洗牌(shuffle)​​:训练时随机打乱数据顺序,提高模型泛化能力

  • ​num_workers​​:数据加载的子进程数,0表示在主进程中加载

7. 卷积神经网络(CNN)架构

class CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 36, 3)self.fc1 = nn.Linear(36 * 6 * 6, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 36 * 6 * 6)x = F.relu(self.fc1(x))x = self.fc2(x)return x

​知识点详解:​

  1. ​nn.Module继承​​:所有自定义网络必须继承自nn.Module

  2. ​卷积层(nn.Conv2d)​​:

    1. conv1: Conv2d(3, 16, 5):输入通道3(RGB),输出通道16,5x5卷积核

    2. conv2: Conv2d(16, 36, 3):输入通道16,输出通道36,3x3卷积核

  3. ​池化层(nn.MaxPool2d)​​:2x2最大池化,步长为2,将特征图尺寸减半

  4. ​全连接层(nn.Linear)​​:

    1. fc1: Linear(36 * 6 * 6, 128):将卷积输出展平后连接到128维隐藏层

    2. fc2: Linear(128, 10):输出10个类别的分数

  5. ​激活函数(F.relu)​​:Rectified Linear Unit,引入非线性

  6. ​张量重塑(x.view())​​:将多维特征张量展平为一维,供全连接层使用

8. 设备管理与GPU加速

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

​知识点详解:​

  1. ​设备检测​​:自动检测可用的计算设备(GPU或CPU)

  2. ​模型迁移​​:使用.to(device)将模型参数和缓冲区移动到指定设备

  3. ​GPU加速​​:利用CUDA进行并行计算,大幅提高训练速度

9. 损失函数与优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

​知识点详解:​

  1. ​交叉熵损失(nn.CrossEntropyLoss)​​:多分类问题的标准损失函数,结合了Softmax和负对数似然

  2. ​随机梯度下降(SGD)​​:基本优化算法

    1. ​lr(学习率)​​:控制参数更新步长

    2. ​momentum(动量)​​:加速SGD在相关方向上的收敛,抑制振荡

10. 模型训练循环

for epoch in range(10):for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()

​知识点详解:​

  • ​epoch​​:完整遍历整个训练集的次数

  • ​前向传播​​:计算模型输出outputs = net(inputs)

  • ​损失计算​​:计算预测值与真实值之间的差异loss = criterion(outputs, labels)

  • ​反向传播​​:loss.backward()自动计算所有参数的梯度

  • ​梯度清零​​:optimizer.zero_grad()清除上一轮的梯度,避免累积

  • ​参数更新​​:optimizer.step()根据梯度更新模型参数

11. 模型评估与准确率计算

with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()

​知识点详解:​

  1. ​torch.no_grad()​​:禁用梯度计算,减少内存消耗,加速推理

  2. ​预测类别​​:torch.max(outputs.data, 1)返回每行最大值的索引(预测类别)

  3. ​准确率计算​​:比较预测值与真实值,统计正确预测的数量

  4. ​逐类别统计​​:分别计算每个类别的准确率,分析模型在不同类别上的表现

12. 可视化与调试

def imshow(img):img = img / 2 + 0.5  # 反归一化npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# 显示样本图像
imshow(torchvision.utils.make_grid(images))

​知识点详解:​

  1. ​图像反归一化​​:将标准化后的图像恢复为可显示格式

  2. ​张量转numpy​​:.numpy()将PyTorch张量转换为NumPy数组

  3. ​维度转换​​:np.transpose()将通道维度从(C,H,W)转换为(H,W,C)

  4. ​matplotlib可视化​​:使用plt.imshow()显示图像

http://www.dtcms.com/a/415960.html

相关文章:

  • 网站不兼容怎么办百度竞价运营
  • 做网站的合作案例影响网站打开速度的因素
  • 网站备案一般要多久网站备案是备什么
  • 自己怎么做网站卖东西建设网站挂广告赚钱
  • 加强网站微信信息编辑队伍建设查询建设工程规范的网站
  • 最炫表白网站html5源码重庆大足网站制作公司哪家专业
  • 网站备案怎么关闭网站wordpress ssl 500
  • 网站建设济南有做的吗政务公开和网站建设工作的建议
  • 【BOOST升压电路】2022-12-8
  • Linux学习笔记(七)--进程状态
  • 网站内搜索关键字怎么做手机app软件
  • 招标网站有哪些网站后台密码忘记了怎么办
  • 第52篇:AI+交通:智能驾驶、交通流优化与智慧物流
  • SQL 优化实战案例:从慢查询到高性能的完整指南
  • 响应式网站做优化好吗wordpress的d8主题
  • MATLAB基于加速遗传算法投影寻踪模型的企业可持续发展能力评价研究
  • 个人网站怎么做收款链接网络营销案例分析报告
  • 做物流网站模块科汛kesioncms网站系统
  • Kafka 合格候选主副本(ELR)在严格 min ISR 约束下提升选主韧性
  • 成都市城乡建设局网站法制教育网站
  • PyQt和PySide中使用Qt Designer
  • 网站建设虚拟云虚拟主机怎么做2个网站
  • 网站建设合同附加协议江门专业做网站
  • 郑州网站制丹东静态管理
  • 网站建设落后发言网站收录作用
  • 虚拟线程的隐形陷阱:Redisson订阅锁超时异常深度剖析
  • 电脑 手机网站建站wordpress主题:yusi v2.0
  • 中材矿山建设有限公司网站wordpress文章关键词描述
  • 云原生架构实战:Kubernetes+ServiceMesh深度解析
  • 重庆网站建设 沛宣企业oa系统免费