当前位置: 首页 > news >正文

机器学习与深度学习4:数据集处理Dataset,DataLoader,batch_size

        深度学习中,我们能看到别人的代码中都有一个继承Dataset类的数据集处理过程,这也是深度学习处理数据集的的基础,下面介绍这个数据集的定义和使用:

1、数据集加载

1.1 通用的定义

Bach:表示每次喂给模型的数据

Epoch:表示训练一次完整数据集数据的过程

解释:当一个数据集的大小为10时,设定batch大小为5,那么这个数据就会分为2份,每份大小为5,依次投入到模型中进行训练。训练完所有数据后,就叫做一次迭代,称为epoch

1.2 继承Dataset类

我们继承Dataset类需要实现它的三个方法,代码在文末,与Dataloader代码一起。

init:载入数据

getitem:返回指定位置数据

len:返回数据长度

固定用法如下:

import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):

    def __init__(self):
        #载入数据
        pass
    def __getitem__(self, item):
        #返回相应位置的数据
        pass
    def __len__(self):
        #返回数据长度
        pass

 例如我们有数据集为手写数字识别数据,文件目录如下:

        在pytorch当然最简单的是用内置的MNIST函数,这里不使用该方法,使用Dataset类写一下。

载入数据:由于数据量太大,因此我们载入每个数据的索引,也就是数据的路径

返回相应位置的数据:实现给出index,能返回相应位置的数据。

返回数据长度:返回所有数据的个数。

1.3 代码实现

灰度图转换(任选其一)

任选其一都可以实现,将原始图片转为灰度图:

transforms.Grayscale(num_output_channels=1)#transform实现转换
Image.open(image_path).convert("L")        #image库转换灰度图

因此可以写出Dataset类加载代码 :

transform = transforms.Compose([
    #transforms.Grayscale(num_output_channels=1),  # 转换为单通道灰度图
    transforms.ToTensor()  # 转换为张量
])
class MyDataset(Dataset):
    def __init__(self):
        # 载入数据
        self.images = []
        self.labels = []
        for i in range(10):
            pathX =os.path.join('../mnist_images/train',str(i))
            imageNameList = os.listdir(pathX)
            image = []
            for filename in imageNameList:
                imagePath = os.path.join('../mnist_images/train',str(i),filename)
                image.append(imagePath)
            label = [i] * len(image)
            #label = [i for _ in range(len(image))]列表推导式
            self.images.extend(image)
            self.labels.extend(label)
    def __getitem__(self, item):
        #返回相应位置的数据
        image = Image.open(self.images[item]).convert("L")
        #image = Image.open(self.images[item])
        return transform(image),torch.tensor(self.labels[item])#返回一个元组
    def __len__(self):
        #返回数据长度
        return len(self.images)

1.4 Dataloader批量加载 

        使用Dataset函数处理数据集后,就需要使用Dataloader,它的使用很简单,只有一行:

DataLoader(oneDataset, batch_size=32, shuffle=True, drop_last = False,num_works = 8)

        其中oneDateset表示输入的Dataset对象下面是对其中一些参数的解释:

batach_size 表示一个Batch的大小

shuffle 表示是否打乱数据

drop_last 表示是否舍弃最后数据,若为True那么会舍弃Datasize对batch_size不能整除的部分,也就是如果数据量为10,batch_size为3的话,最后一个数据会被舍弃,如果drop_last为False的话,最后一个数会被保留。也就是最后一个batch_size的大小为1。

num_works 表示使用多少进程加载数据,num_works = 0表示使用主进程加载数据,num_works > 0表示使用多少个子进程加载数据。

        DataLoader返回为一个张量形状为[batch_size, channels, height, width] batch_size表示批量大小,可以是任意正整数,训练模型时,模型输入对该参数batch_size无要求限制,但是后面的三个特征维度[channels, height, width]必须跟模型model定义的输入层数据维度一致。

1.5完整代码:

import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

transform = transforms.Compose([
    #transforms.Grayscale(num_output_channels=1),  # 转换为单通道灰度图
    transforms.ToTensor()  # 转换为张量
])
class MyDataset(Dataset):
    def __init__(self):
        # 载入数据
        self.images = []
        self.labels = []
        for i in range(10):
            pathX =os.path.join('../mnist_images/train',str(i))
            imageNameList = os.listdir(pathX)
            image = []
            for filename in imageNameList:
                imagePath = os.path.join('../mnist_images/train',str(i),filename)
                image.append(imagePath)
            label = [i] * len(image)
            #label = [i for _ in range(len(image))]列表推导式
            self.images.extend(image)
            self.labels.extend(label)
    def __getitem__(self, item):
        #返回相应位置的数据
        image = Image.open(self.images[item]).convert("L")
        #image = Image.open(self.images[item])
        return transform(image),torch.tensor(self.labels[item])#返回一个元组
    def __len__(self):
        #返回数据长度
        return len(self.images)
def getDataloder():
    oneDataset = MyDataset()
    return DataLoader(oneDataset, batch_size=32, shuffle=True)
if __name__ == '__main__':
    dataloader = getDataloder()
    for images, labels in dataloader:
        print("Batch shape:", images.shape)  # 输出批次形状
        print("Labels:", labels)  # 输出标签
        #print(images[0][0][18])
        break  # 只打印第一个批次

二、 文件下载

文件项目是一个完整的简单神经网络训练手写数字识别,打包下载在这里:点击下载项目

        最后:实现手写数字识别数据集加载方法最简单的是使用pytorch内置MNIST函数实现,仅有一行代码实现上述功能,本文不采用该方法,通过自行实现理解数据集加载原理。

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

相关文章:

  • 动态规划:路径类dp
  • JWT、seesion、cookie、csrf漏洞
  • Git回退文件到指定提交
  • 告别代码Bug,GDB调试工具详解
  • 《Spring Cloud Eureka 高可用集群实战:从零构建 99.99% 可靠性的微服务注册中心》
  • 智能设备定制PCBA板卡快速接入OPC UA系统
  • Elasticsearch-实战案例
  • 反射、枚举以及lambda表达式
  • 多台 Windows 电脑之间共享鼠标和键盘,并支持 剪贴板同步(复制粘贴)
  • 解锁算法密码:多维度探究动态规划,贪心,分治,回溯和分支限界经典算法
  • 个人学习编程(3-27) leetcode刷题
  • JavaScript 调试入门指南
  • 鸿蒙UI开发
  • ​​SenseGlove与Aeon Robotics携手推出HEART项目,助力机器人培训迈向新台阶
  • 【银河麒麟系统常识】命令:uname -m(查看系统架构)
  • FFmpeg —— 在Linux下使用FFmpeg拉取rtsp流解码,留出图像接口供OpenCv处理等(附:源码)
  • Spring Boot使用异步线程池
  • Linux文件搜索与文本过滤全攻略:find、locate、grep深度解析
  • 巧文书-标书产品功能介绍
  • Linux的例行性工作
  • 南宁网站建设南宁/如何制作网站赚钱
  • 去菲律宾做it网站开发/百度关键词推广怎么做
  • nodejs可以做网站吗/今日新闻最新消息50字
  • 牛商网做网站的思路/seo关键词使用
  • 中山手机网站制作哪家好/重庆seo网页优化
  • 欧美做暖网站/被逆冬seo课程欺骗了