当前位置: 首页 > news >正文

使用 PyTorch 的 torchvision 库加载 CIFAR-10 数据集

CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

import torchvision
import torchvision.transforms as transforms# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 自动下载训练集
trainset = torchvision.datasets.CIFAR10(root='./data',  # 数据保存路径train=True,download=True,  # 设置为True自动下载transform=transform
)# 自动下载测试集
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform
)

1. 导入必要的库

import torchvision

import torchvision.transforms as transforms

  1. torchvision:PyTorch 的视觉库,提供常用数据集、模型架构和图像转换工具。
  2. transforms:用于图像预处理的模块,如缩放、归一化等。

2. 定义数据预处理流程

transform = transforms.Compose([

    transforms.ToTensor(),

    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

])

  1. transforms.Compose:将多个预处理操作按顺序组合。
  2. transforms.ToTensor()
    1. 将 PIL 图像或 NumPy 数组(H×W×C,范围 0-255)转换为 PyTorch 张量(C×H×W,范围 0.0-0)。
  3. transforms.Normalize(mean, std)
    1. 对每个通道进行归一化:output = (input - mean) / std
    2. 这里mean=(0.5, 0.5, 0.5)std=(0.5, 0.5, 0.5)将像素值从[0.0, 1.0]映射到[-1.0, 1.0](例如,0.0→-1.0,1.0→1.0)。

3. 下载并加载训练集

trainset = torchvision.datasets.CIFAR10(

    root='./data',  # 数据保存路径

    train=True,     # True表示训练集(50,000张)

    download=True,  # 自动下载(如果数据不存在)

    transform=transform  # 应用预处理

)

  1. torchvision.datasets.CIFAR10:CIFAR-10 数据集类,包含 10 个类别(如飞机、汽车、鸟类等)的 60,000 张 32×32 彩色图像。
  2. 参数说明
    1. root='./data':数据将下载到当前目录的data文件夹中。
    2. train=True:加载训练集(50,000 张);若为False则加载测试集(10,000 张)。
    3. download=True:若数据不存在,自动从互联网下载(约 170MB)。
    4. transform=transform:对图像应用之前定义的预处理(转为张量并归一化)。

4. 下载并加载测试集

testset = torchvision.datasets.CIFAR10(

    root='./data',  # 与训练集路径一致

    train=False,    # 加载测试集

    download=True,  # 自动下载

    transform=transform  # 应用相同的预处理

)

  1. 测试集与训练集结构相同,但用于模型评估,不参与训练。

5. 数据验证与使用

下载完成后,数据将存储在./data/cifar-10-batches-py目录中。你可以:

  1. 查看数据集大小

print(len(trainset))  # 输出: 50000

print(len(testset))   # 输出: 10000

  1. 访问单个样本

image, label = trainset[0]  # 获取第一张图像及其标签

print(image.shape)  # 输出: (3, 32, 32)

print(label)        # 输出: 6(对应类别索引)

  1. 使用数据加载器批量处理数据

from torch.utils.data import DataLoader

trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

testloader = DataLoader(testset, batch_size=32, shuffle=False)

注意事项

  1. 下载路径
    1. 若指定路径(如./data)已存在 CIFAR-10 数据,download=True不会重复下载。
    2. 若路径错误或无写入权限,会抛出异常(如PermissionError)。
  2. 网络问题
    1. 首次下载需联网,可能需要几分钟。若下载中断,可删除./data目录后重新运行。
  3. 数据预处理
    1. 归一化参数meanstd通常根据数据集的统计特性设定。对于 CIFAR-10,常用(0.5, 0.5, 0.5)进行简单归一化。
    2. 若需要更精确的归一化,可计算数据集的真实均值和标准差(如mean=[0.4914, 0.4822, 0.4465]std=[0.2470, 0.2435, 0.2616])。

扩展应用

加载数据后,可用于训练 CNN 模型(如之前创建的SimpleCNN):

# 假设model已定义

from torch import nn, optim

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练循环

for epoch in range(5):  # 训练5个轮次

    for inputs, labels in trainloader:

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = criterion(outputs, labels)

        loss.backward()

        optimizer.step()

    print(f"Epoch {epoch+1} completed")

http://www.dtcms.com/a/291213.html

相关文章:

  • python 中if/elif/else 是如何构建程序逻辑的?
  • 【初识数据结构】CS61B中的最小生成树问题
  • LLaMA-Factory 微调可配置的模型基本参数
  • jcmd用法总结
  • 完整的 SquareStudio 注册登录功能实现方案:已经烧录到开发板正常使用
  • 83、形式化方法
  • Unity VR多人手术系统恢复3:Agora语音通讯系统问题解决全记录
  • 【CAN】01.CAN简介硬件电路
  • 视网膜分支静脉阻塞(BRVO)及抗VEGF治疗的多模态影像学研究
  • 同步与异步?从一个卡顿的Java服务说起
  • 文字检测到文字识别
  • 如何用 Z.ai 生成PPT,一句话生成整套演示文档
  • 自反馈机制(Self-Feedback)在大模型中的原理、演进与应用
  • 【PTA数据结构 | C语言版】哥尼斯堡的“七桥问题”
  • 【ROS1】07-话题通信中使用自定义msg
  • (9)机器学习小白入门 YOLOv:YOLOv8-cls 技术解析与代码实现
  • 选择排序 冒泡排序
  • LinkedList与链表(单向)(Java实现)
  • android studio 远程库编译报错无法访问远程库如何解决
  • 算法提升之字符串回文问题-(马拉车算法)
  • Java基础教程(011):面向对象中的构造方法
  • 模拟高负载测试脚本
  • Flink框架:keyBy实现按键逻辑分区
  • 250kHz采样率下多信号参数设置
  • mysql-5.7 Linux安装教程
  • 无人机报警器技术要点与捕捉方式
  • Anaconda 路径精简后暴露 python 及工具到环境变量的配置记录 [二]
  • Linux学习之Linux系统权限
  • scratch音乐会开幕倒计时 2025年6月中国电子学会图形化编程 少儿编程 scratch编程等级考试一级真题和答案解析
  • Git核心功能简要学习