当前位置: 首页 > news >正文

基于PyTorch的Fashion-MNIST图像分类数据集处理与可视化

1. 引言

在本项目中,我们使用 PyTorch 框架加载、处理并可视化了经典的 Fashion-MNIST 图像分类数据集。
本文涵盖了完整的代码、详细注释以及执行后的输出结果,非常适合初学者参考与学习。

2. 环境准备

首先导入必要的库,并设置图片显示为 SVG 格式以提高显示质量。

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()  # 使用SVG格式显示图像

3. 数据集下载与预处理

使用 torchvision.datasets.FashionMNIST 下载数据,并使用 ToTensor 将图片数据从PIL格式转换为 float32 类型,同时将像素值归一化到0-1之间。

# 实例化ToTensor()
trans = transforms.ToTensor()# 下载训练集和测试集
train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)# 查看数据集大小
len(train_data), len(test_data)

输出结果

(60000, 10000)

训练集有6万张图片,测试集有1万张图片。

4. 查看单个样本的尺寸

检查训练集中第一张图片的尺寸信息。

train_data[0][0].shape  # 训练数据第一个样本的维度 (通道数, 宽度, 高度)

 输出结果

torch.Size([1, 28, 28])

可以看到,每张图片是1通道(灰度图),28×28像素。

5. 可视化图像及标签

接下来定义辅助函数:将标签ID转为文字标签,并绘制多张图片。

# 将标签ID映射为文字标签
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 显示图像函数
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

 从训练集中随机抽取18张图片进行展示:

x, y = next(iter(data.DataLoader(train_data, batch_size=18)))
show_images(x.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

效果(示意):

6. 数据迭代性能测试

测试使用DataLoader读取整个训练集需要的时间:

# 创建数据迭代器
train_iter = data.DataLoader(train_data, 256, shuffle=True, num_workers=0, drop_last=False)timer = d2l.Timer()  # 创建计时器
for x, y in train_iter:continue
timer.stop()

 输出结果

4.05924654006958

即:加载完整训练集约需要4秒左右(硬件环境不同时间会略有波动)。

7. 封装数据加载函数

为了方便后续调用,将上面的步骤封装成一个函数 load_data_fashion_mnist

def load_data_fashion_mnist(batch_size, resize=None):# 下载Fashion-MNIST数据集,并将其加载到内存中trans = [transforms.ToTensor()]  # 将图像转换为张量if resize:trans.insert(0, transforms.Resize(resize))  # 若指定,先调整图像大小trans = transforms.Compose(trans)  # 多个变换组合train_data = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)test_data = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)return (data.DataLoader(train_data, batch_size, shuffle=True, num_workers=0, drop_last=False),data.DataLoader(test_data, batch_size, shuffle=False, num_workers=0, drop_last=False))

8. 总结

本文从最基础的环境配置、数据下载与处理、可视化展示,到性能测试与函数封装,完整梳理了Fashion-MNIST数据集在PyTorch中的处理流程。
下一步,我们可以基于此数据集继续训练深度学习模型,如MLP、多层卷积神经网络(CNN)等。

相关文章:

  • Java后端图形验证码的使用
  • [Linux网络_68] 转发 | 路由(Hop by Hop) | IP的分片和组装
  • 当OA闯入元宇宙:打卡、报销和会议的未来狂想
  • 【C++11】包装器:function与bind
  • 【BotSharp框架示例 ——实现聊天机器人,并通过 DeepSeek V3实现 function calling】
  • 【MuJoCo仿真】开源SO100机械臂导入到仿真环境
  • 在 Ubuntu 上离线安装 ClickHouse
  • ShaderToy学习笔记 05.3D旋转
  • 人工智能数学基础(三):微积分初步
  • 深入解析常见排序算法及其 C# 实现
  • 初识Redis · 分布式锁
  • Go 语言中的 `recover()` 函数详解
  • 医疗生态全域智能化:从技术革新到价值重塑的深度探析
  • 基于Spring Boot 3.0、ShardingSphere、PostgreSQL或达梦数据库的分库分表
  • Go语言之路————接口、泛型
  • 在Anolis OS 8上部署Elasticsearch 7.16.1与JDK 11的完整指南
  • 首页数据展示
  • keep-alive具体使用方法
  • C++多线程与锁机制
  • MySQL 在 CentOS 7 环境下的安装教程
  • 违规行为屡禁不止、责任边界模糊不清,法治日报:洞穴探险,谁为安全事故买单?
  • 对谈|李钧鹏、周忆粟:安德鲁·阿伯特过程社会学的魅力
  • 国家核准10台核电新机组,四大核电央企披露新项目进展
  • 现场|西岸美术馆与蓬皮杜启动新五年合作,新展今开幕
  • 北上广深还是小城之春?“五一”想好去哪玩了吗
  • 地下管道密布成难题,道路修整如何破局?