当前位置: 首页 > news >正文

Class3图像分类数据集代码

Class3图像分类数据集代码

%matplotlib inline
import torch
# 导入torchvision库,包含计算机视觉数据集、模型结构、图像变换工具
import torchvision
from torch.utils import data
# 导入图像预处理工具,用于对图像进行变换
from torchvision import transforms
from d2l import torch as d2l
# 使用SVG格式提升绘图清晰度
d2l.use_svg_display()
# 定义图像转换器,将PIL图像或Numpy数组转换为PyTorch张量
trans = transforms.ToTensor()
# 下载训练集的Fashion-MNIST数据,使用trans转换器处理图像
mnist_train = torchvision.datasets.FashionMNIST(root = "../data",train=True,transform=trans,download=True)
# 下载测试集的Fashion-MNIST数据,使用trans转换器处理图像
mnist_test = torchvision.datasets.FashionMNIST(root = "../data",train=False,transform=trans,download=True)
# 查看训练集和测试集的大小
len(mnist_train),len(mnist_test)
# 查看首个位置的元素形状
mnist_train[0][0].shape
# 标签转换函数
def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""# 设置0-9共10类标签text_labels = ['t-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']# 循环返回每个标签return [text_labels[int(i)] for i in labels]# 图像显示函数
# imgs:图像列表
# num_rows:图像显示的行数
# num_cols:图像显示的列数
# title:标题列表
# scale:缩放图像大小
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):"""绘制图像列表"""# 计算画布尺寸figsize = (num_cols * scale,num_rows * scale)# 创建子图_,axes = d2l.plt.subplots(num_rows,num_cols,figsize=figsize)axes = axes.flatten()# axes:子图的轴对象数组# imgs:图像数据列表# zip:图像和对应显示位置配对# enumerate:同时获取索引i和内容(ax,imgs)for i,(ax,img) in enumerate(zip(axes,imgs)):# 判断是否为PyTorch的Tensor类型if torch.is_tensor(img):# 是则转换为Numpy再绘图ax.imshow(img.numpy())else:# 不是则直接显示ax.imshow(img)# 取出下方的横坐标和左侧的纵坐标ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)# 设置图像标题if titles:ax.set_title(titles[i])return axes
# 从训练数据集中随机读取一批图像,分别获取图像数据X和标签y
X,y = next(iter(data.DataLoader(mnist_train,batch_size=18)))
# 将图像和对应的标签名可视化
show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y));
# 批量大小为256
batch_size = 256
# 定义加载进程函数
def get_dataloader_workers():"""使用4个进程来读取数据"""return 4
# 创建训练数据加载器
# mnist_train:训练数据集
# batch_size:每批图像的数量
# shuffle=True:打乱数据,防止过拟合
# num_workers=4:使用4个进程加载数据
train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
# 创建计时器
timer = d2l.Timer()
# 遍历整个训练集
for X,y in train_iter:continue
# 停止计时并输出结果
f'{timer.stop():.2f} sec'
# 下载数据集并生成训练集和测试集的DataLoader
def load_data_fashion_mnist(batch_size,resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""# 将图像转换为Tensor格式trans = [transforms.ToTensor()]if resize:# 把图像调整为指定大小trans.insert(0,transforms.Resize(resize))# 将多个变换组合,按顺序执行trans = transforms.Compose(trans)# 加载训练集mnist_train = torchvision.datasets.FashionMNIST(root = "../data",train=True,transform=trans,download=True)# 加载测试集mnist_test = torchvision.datasets.FashionMNIST(root = "../data",train=True,transform=trans,download=True)# 创建DataLoaderreturn (data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
# 定义训练集迭代器和测试集迭代器
train_iter,test_iter = load_data_fashion_mnist(32,resize=64)
# 从训练集中取出一批数据
for X,y in train_iter:# X:图像张量# y:标签张量print(X.shape,X.dtype,y.shape,y.dtype)break
http://www.dtcms.com/a/265781.html

相关文章:

  • 数学建模_时间序列
  • CTF Web PHP弱类型与进制绕过(过滤)
  • 【云计算】企业项目 策略授权
  • 网络层:ip协议 与数据链路层
  • C++反射之获取可调用对象的详细信息
  • 《Spring 中上下文传递的那些事儿》Part 2:Web 请求上下文 —— RequestContextHolder 与异步处理
  • 低代码实战训练营教学大纲 (10天)
  • Linux之Socket 编程 UDP
  • 自然光实时渲染~三维场景中的全局光照
  • osg加入实时光照SilverLining 天空和3D 云
  • 租车小程序电动车租赁小程序php方案
  • Flutter 3.29+使用isar构建失败
  • 创客匠人视角:知识变现与创始人 IP 打造的破局之道
  • centos7源码编译安装python3
  • SSM和SpringBoot框架的关系
  • 关于微前端框架micro,子应用设置--el-primary-color失效的问题
  • FPGA从零到一实现FOC(一)之PWM模块设计
  • 火语言 RPA:突破企业自动化瓶颈,释放数字生产力​
  • Linux基本命令篇 —— zip/unzip命令
  • Apache Commons Pool中的GenericObjectPool详解
  • 华为Freebuds 6i新音效,设置后音质敲好!
  • Nginx安全配置漏洞修复实战指南
  • 百度文心智能体平台x小米应用商店:联手打造行业首个智能体与应用市场跨端分发模式
  • React 强大的表单验证库formik之集成Yup、React Hook Form库
  • 使用 Dockerfile 构建基于 .NET9 的跨平台基础镜像
  • 安卓开机自启动方案
  • Kafka生态整合深度解析:构建现代化数据架构的核心枢纽
  • Sklearn安装使用教程
  • 机器人焊接电源节气阀
  • 工程化实践——标准化Eslint、PrettierTS