从头开始复现YOLOv3(二)数据类

wuchangjian2021-10-31 17:15:57编程学习

创建数据类

  • 1 数据组织
    • (1)coco128数据集
    • (2)从训练集中划分出验证集
  • 2 为数据集创建类

1 数据组织

(1)coco128数据集

这里我们使用coco数据集,但coco数据集实在太大,这里我们使用coco128数据集,因为只有128张图片,复制和解压的速度可以很快,coco128数据集下载链接为:https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
在当前目录下建立一个名为data的文件夹,将coco128整个文件夹放入data中,此时项目结构如下图所示:
在这里插入图片描述
打开 data/coco128/labels/train2017/000000000009.txt,可以看到
在这里插入图片描述
每行第一个数字为目标的类别编号,第二个是边框中心点的横坐标,第三个是纵坐标,第四个是边框的宽,第五个是边框的高,中心点坐标和宽高均为归一化后的结果。

进入data/coco128/images/train2017中,可以看到有一个名称为“.DS_Store”的文件,将这个文件删掉,防止其被当成样本
在这里插入图片描述

(2)从训练集中划分出验证集

在data/coco128下新一个名为“划分数据集.py”的Python脚本,该脚本的程序设计思路为:先求有多少图片(使用os.listdir将图片目录中的图片文件名,做成一张列表,求列表的长度),再生成10个随机数,作为索引从图片名够成的列表中获取10张图片的名称,将这10张图片的路径写进一个名为“val_path.txt”的文件中,再将其余的图片写入一个名为“train_path.txt”的文件中。代码如下:

import os
import random
random.seed(10)

# 从图片路径中获取图片名
images_path = r"F:\thesis\yolo3_from_scratch\data\coco128\images\train2017"
images_names = os.listdir(images_path)      # 以列表的方式返回一级子目录
images_num = len(images_names)

# 随机获取10个数,作为图片名列表的索引,进而获得图片名
num_val = 10                # 验证集的数目
idx_img = random.sample(range(0, images_num), num_val)   # 生成不重复的num_val个随机数

# 生成验证集样本的名字,构成一个列表
val_names = [images_names[i] for i in idx_img]
# print(val_names)

# 生成训练集样本的名字,构成一个列表
train_names = [images_names[i] for i in range(images_num) if i not in idx_img]
# print(train_names)

# 将训练集样本的路径写入txt文件
with open(r"F:\thesis\yolo3_from_scratch\data\coco128\train_path.txt", 'w') as f:
    for file_name in train_names:
        f.write(os.path.join(images_path, file_name)+'\n')

# 将验证集样本的路径写入txt文件
with open(r"F:\thesis\yolo3_from_scratch\data\coco128\val_path.txt", 'w') as f:
    for file_name in val_names:
        f.write(os.path.join(images_path, file_name)+'\n')

程序执行后,项目结构为:
在这里插入图片描述
其中,train_path.txt文件内容如下:
在这里插入图片描述
val_path.txt文件内容如下:
在这里插入图片描述
至此,数据组织完毕。

2 为数据集创建类

在utils的目录下,新建一个文件,取名为datasets.py,此时,项目结构变成了
在这里插入图片描述
在datasets.py中,先写入需要的模块:

import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import random
random.seed(0)

当然,上面的模块还不够,但剩余的我们可以在建立类的过程中逐渐加入。

在datasets.py中,建立一个数据集类,该类继承torch.utils.data中的Dataset类,自制的数据集类必须实现三个函数: initlen__和__getitem,分别是初始化类,求长度len(obj),通过索引获得单个样本及其标签。

初始化函数没什么好讲的,其代码如下:

class ListDataset(Dataset):
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        '''

        :param list_path: 一个txt文件,比如我们前面写的train_path.txt和val_path.txt
        :param img_size: 数据图片要转成成的高
        :param augment: 是否使用数据增强
        :param multiscale: 是否进行多尺度变换(看self.collate_fn就能明白它的作用)
        :param normalized_labels: 标签是否已经归一化,即boundingbox的中心坐标,高宽等是否已经归一化
        '''
        with open(list_path, "r") as file:
            self.img_files = file.readlines()   # 读取txt文件的内容,将样本路径读取出来

        # 标签的路径,可以根据样本的路径来获得,只需要将路径名中的images改成labels,后缀改成txt就行
        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        self.img_size = img_size    # 图片处理成方形后的高宽(图片在输入模型前要处理成方形)
        self.max_objects = 100      # 一张图片中的最大目标数
        self.augment = augment
        self.multiscale = multiscale
        self.normalized_labels = normalized_labels
        self.min_size = self.img_size - 3 * 32  # 在进行多尺度变换时的最小尺度
        self.max_size = self.img_size + 3 * 32  # 在进行多尺度变换时的最大尺度
        self.batch_count = 0                    # 统计已经遍历了多少个batch
        # TODO max_objects是用来干嘛的

求长度的函数,也比较简单,代码如下:

    def __len__(self):
        return len(self.img_files)

比较难的是getitem函数,要返回指定索引的图片和标签,在函数返回之前,必须将图片和标签做成张量,还涉及到标签是否归一化的问题,所以比较繁琐。

我们先来对图片处理,其代码如下:

    def __getitem__(self, index):
        img_path = self.img_files[index % len(self.img_files)].rstrip()     # 获取图片的路径名

        # Extract image as PyTorch tensor
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))    # 读取图片并转化为torch张量
        # mage.open(img_path)读取图片,返回Image对象,不是普通的数组
        # convert('RGB')进行通道转换,因为当图像格式为RGBA时,Image.open(‘xxx.jpg’)读取的格式为RGBA

        # Handle images with less than three channels
        if len(img.shape) != 3:
            img = img.unsqueeze(0)
            img = img.expand((3, img.shape[1:]))
        # 图片有可能是一张灰度图,那么img.shape就是(h, w)
        # unsqueeze(0)之后,就是img.shape就是(1, h, w)
        # img.expand((3, img.shape[1:])) 即为 img.expand((3, h, w))

        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # h_factor, w_factor在后面用来反算目标在图片中的具体坐标,看到本函数的后面,自然能明白
        # 如果已经归一化,那么比例因子就是图片的真实高宽
        # 如果未归一化,那比例因子就是1

        # Pad to square resolution
        img, pad = pad_to_square(img, 0)        
        _, padded_h, padded_w = img.shape
        

这里出现了pad_to_square函数,它通过letter_box算法,将图片变成方形,我们可以在ListDataset类的后面,加上pad_to_square的代码:

import torch.nn.functional as F
def pad_to_square(img, pad_value):
    """
    该函数是将图片扩充成正方形
    :param img: 图片张量
    :param pad_value:   用来填充的值,即左右或者上下的条
    :return:
    """
    c, h, w = img.shape
    dim_diff = np.abs(h - w)
    # (upper / left) padding and (lower / right) padding
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    
    # Determine padding 
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    # (0, 0, pad1, pad2)和(pad1, pad2, 0, 0),括号中的四个值,分别表示左右上下
    # 如果h小于w,那么就是在上下填充,否则在左右填充
    # 因为后面使用F.pad函数,第二个参数是pad,它是一个包含四个数的元组
    
    # Add padding
    img = F.pad(img, pad, "constant", value=pad_value)

    return img, pad

我们回到ListDataset类的getitem函数中来,既然对图片进行letter_box处理了,那么对应的标签也要做相应的变换,代码如下:

        label_path = self.label_files[index % len(self.img_files)].rstrip() # 获取标签路径

        targets = None
        if os.path.exists(label_path):
            f = open(label_path, 'r')
            if f.readlines() != []:
                # 有些图片没有目标,但有标签文件,这些标签文件中没有内容
                # 我们这边只处理有内容的标签文件,对于没有内容的标签文件,让targets等于None

                boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))

                # Extract coordinates for unpadded + unscaled image
                # 获取bbox左上角和右下角点在原始图片上的真实坐标
                x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
                y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
                x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
                y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)

                # Adjust for added padding
                # 由于已经被调整成了方形,因此需要加上pad的尺寸
                x1 += pad[0]
                y1 += pad[2]
                x2 += pad[1]
                y2 += pad[3]

                # Returns (x, y, w, h)
                # 求归一化后的中心点坐标和高宽
                boxes[:, 1] = ((x1 + x2) / 2) / padded_w
                boxes[:, 2] = ((y1 + y2) / 2) / padded_h
                boxes[:, 3] *= w_factor / padded_w
                boxes[:, 4] *= h_factor / padded_h

                targets = torch.zeros((len(boxes), 6))
                targets[:, 1:] = boxes          # 后面5列分别是bbox的位置和高宽,然后是分类索引
                # target第0列,根据后面的collate_fn函数,可以看到第0列是图片在batch中的索引

                # Apply augmentations
                # 随机进行水平翻转
                if self.augment:
                    if np.random.random() < 0.5:
                        img, targets = horisontal_flip(img, targets)

            f.close()

        return img_path, img, targets

这里出现了horisontal_flip函数,它是自定定义的水平翻转操作,以此作为数据增强的方法。在utils目录下,新建一个名为“augmentations.py”的脚本,其代码如下:

import torch
def horisontal_flip(images, targets):
    """
    水平翻转
    :param images:图片张量
    :param targets:标签
    :return:
    """
    images = torch.flip(images, [-1])   # 按照指定维度进行翻转,-1表示最后一个维度,即为宽
    targets[:, 2] = 1 - targets[:, 2]   # 中心点横坐标也要随机翻转,因为已经归一化,所以直接用1减就行
    return images, targets

记得在datasets.py中加入下面一句话:

from utils.augmentations import horisontal_flip

此时,项目结构为:
在这里插入图片描述
这里我们可以在“/yolo3_from_scratch”下写一个测试脚本datasets_test.py:

# coding=utf-8
import torch
from utils.datasets import ListDataset

train_path = r"F:\thesis\yolo3_from_scratch\data\coco128\train_path.txt"
dataset = ListDataset(train_path, augment=True, multiscale=True)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True
) 

输出:
在这里插入图片描述
之所以会出现这种情况,是因为两张图片的尺寸不一样,他们在经过letter_box之后,无法堆叠,这里需要我们在ListDataset类中增加一个collate_fn函数,来将两个图形缩放成相同的尺寸。关于collate_fn函数,可以看这篇文章:
https://blog.csdn.net/qq_43391414/article/details/120462055

    def collate_fn(self, batch):
        """
        用于整理数据
        :param batch: 若干次__getitem__函数返回的内容组成的列表
                    假如说batch_size为2,那么batch就是由2个元素组成的列表,
                    每个元素代表一次__getitem__函数的返回结果
                    __getitem__函数的返回值有三个部分组成:img_path, img, targets
                    那么batch的每个元素就是一个元组,包含了img_path, img, targets
        :return:
        """
        paths, imgs, targets = list(zip(*batch))    # zip括号中的参数*开头,表示解压缩
        # 上条命令执行之后,paths, imgs, targets都将成为元组,
        # 以paths为例,上述命令执行后,paths将成为由两个图片路径构成的元组

        # Add sample index to targets 将图片在batch中的索引,加到target的第0个列
        for i, boxes in enumerate(targets):
            if boxes is not None:
                boxes[:, 0] = i     # i表示当前batch中的第i张图片

        # Remove empty placeholder targets 有些图片没有目标,那么它对应的标签就是None
        targets = [boxes for boxes in targets if boxes is not None]  # 保留非None的标签

        targets = torch.cat(targets, 0)     # 标签级联,targets在转化前是一个元组

        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:  # 每10个batch,随机改变一下尺度
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
            # range函数的第三个参数是32,能保证随机获得的新尺寸是32的倍数,
            # 因为是backbone是32倍下采样,如果不是32的倍数,那么卷积核不能完全把图片扫描

        # 将图片缩放到指定尺寸
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])    #
        self.batch_count += 1

        # 图片缩放之后,之所以标签不用改变,是因为标签已经归一化了,所以无需转换

        return paths, imgs, targets

上面的程序段中,出现了list(zip(*batch))和resize,我们先说一下list(zip(*batch)),通过下面的程序段可以了解它的功能:

# 这段代码与YOLOv3无关,仅仅解释list(zip(*batch))实现了什么
a = ('a', 25, 1)
b = ('b', 43, 0)
L= [a, b]
m = zip(*L)     # zip对象
print(m)
print(list(m))

输出:

<zip object at 0x00000000026828C0>
[('a', 'b'), (25, 43), (1, 0)]

collate_fn中出现了resize函数,可以在datasets.py中加上这个函数:

def resize(image, size):
    """
    将图片缩放成指定尺寸
    :param image: 图片张量
    :param size:  指定尺寸,高和宽都是这个值,也就是说,本函数缩放的是正方形
    :return:
    """
    image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
    return image

对测试脚本稍作修改,修改后的datasets_test.py代码如下:

# coding=utf-8
import torch
from utils.datasets import ListDataset

train_path = r"F:\thesis\PyTorch-YOLOv3\data\coco\trainvalno5k.txt"
dataset = ListDataset(train_path, augment=True, multiscale=True)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=dataset.collate_fn,  #
)

for batch_i, (_, imgs, targets) in enumerate(dataloader):
    print(batch_i)
    print("imgs.shape:", imgs.shape)
    print("targets.shape:", targets.shape)
    break

输出:

0
imgs.shape: torch.Size([2, 3, 384, 384])
targets.shape: torch.Size([6, 6])

另外,为了测试数据集类对于没有内容的标签文件的处理,还可以再写一个测试脚本,测试批量导入数据时,能否正确的处理没有目标的标签文件。

# 测试batch为5的时候,第18个batch的targets中,是否有图片索引为1的行
# 1表示第2张图片,即000000000508.jpg的在batch中的索引,而000000000508.jpg中没有目标
# 因为000000000508.jpg就在train_path.txt的第87行,因为这张图片的索引为86,是在第18个batch中
# 可以看看返回的targets的最后几行的图片索引

import torch
from utils.datasets import ListDataset

train_path = r"F:\thesis\yolo3_from_scratch\data\coco128\train_path.txt"
dataset = ListDataset(train_path, augment=True, multiscale=True)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,              # 这边不能打乱
    collate_fn=dataset.collate_fn,  #
)

for batch_i, (imgs_path, imgs, targets) in enumerate(dataloader):
    if batch_i<17:
        continue

    print(targets[:3, :])   # 只显示前三行,因为索引为85的图片,只有一个目标
    # 如果第2行以2开头,那说明程序修改的正确,如果以1开头,说明程序出问题
    break

# 通过输出可以看到,targets的第2行是以2开头,说明程序对00....0508.txt处理得当。

输出:
在这里插入图片描述
打开划分数据时建立的文件train_path.txt,可以看到502.jpg,508.jpg,510.jpg三个文件是相邻的
在这里插入图片描述
至此数据集类初步建立起来了,后续我们会根据需要继续加入一下方法和类。

下一章我们讲解一下训练YOLOv3模型。

相关文章

每日一学-002 CSS3 @support详解

参考链接 写法 实用例子...

Found no valid file for the classes .ipynb_checkpoints

删除掉隐藏文件.ipynb_checkpoints即可...

SpringCloud 和Dubbo 的区别

Dubbo关注的领域是Spring  Cloud的一个子集。Dubbo专注与服务治理&#...

发表评论    

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。