当前位置: 首页 > news >正文

【DETR】训练自己的数据集以及YOLO数据集格式(txt)转化成COCO格式(json)

目录

  • 1.DETR介绍
  • 2.数据集处理
  • 3.转化结果可视化
  • 4.数据集训练
    • 4.1修改pth文件
    • 4.2类别参数修改
    • 4.3训练
  • 5.成功运行!
  • 6.参考文献

1.DETR介绍

DETR(Detection with TRansformers)是基于transformer的端对端目标检测,无NMS后处理步骤,无anchor。
代码链接:https://github.com/facebookresearch/detr
在这里插入图片描述

2.数据集处理

DETR需要的数据集格式为coco格式,这里我是用自己的YOLO格式数据集转化成COCO格式,然后进行训练的。
YOLO数据集的组织格式是:
其中images里面分别存放训练集train和验证集val的图片,labels存放训练集train和验证集val的txt标签。
在这里插入图片描述
要转化成适应DETR模型读取的COCO数据集的组织形式是:
其中train2017存放训练集的图片,val2017存放验证集的图片,
annotations文件夹里面存放train和val的json标签。

在这里插入图片描述
下面是转化代码:

  • 需要进行类别映射,每个类别对应的id分别存放在categories里面,这里我没有用classes.txt文件存放,相当于直接把classes.txt里面的类别写出来了。
  • 我的图片是png格式的,如果图片是jpg格式的,将png改成jpg即可。image_name = filename.replace(‘.txt’, ‘.jpg’)
  • 最后修改文件路径,改成自己的路径,这里最后会输出train和val的json文件,图片不会处理,按上述目录组织形式将图片组织起来即可。
  • 生成的文件夹记得改为instances_train2017.json这种样子
import os
import json
from PIL import Image

# 定义类别映射
categories = [
    {"id": 0, "name": "Double hexagonal column"},
    {"id": 1, "name": "Flange nut"},
    {"id": 2, "name": "Hexagon nut"},
    {"id": 3, "name": "Hexagon pillar"},
    {"id": 4, "name": "Hexagon screw"},
    {"id": 5, "name": "Hexagonal steel column"},
    {"id": 6, "name": "Horizontal bubble"},
    {"id": 7, "name": "Keybar"},
    {"id": 8, "name": "Plastic cushion pillar"},
    {"id": 9, "name": "Rectangular nut"},
    {"id": 10, "name": "Round head screw"},
    {"id": 11, "name": "Spring washer"},
    {"id": 12, "name": "T-shaped screw"}
]

def yolo_to_coco(yolo_images_dir, yolo_labels_dir, output_json_path):
    # 初始化 COCO 数据结构
    data = {
        "images": [],
        "annotations": [],
        "categories": categories
    }

    image_id = 1
    annotation_id = 1

    def get_image_size(image_path):
        with Image.open(image_path) as img:
            return img.width, img.height

    # 遍历标签目录
    for filename in os.listdir(yolo_labels_dir):
        if not filename.endswith('.txt'):
            continue  # 只处理 .txt 文件

        image_name = filename.replace('.txt', '.png')# 如果图片是jpg格式的,将png改成jpg即可。
        
        image_path = os.path.join(yolo_images_dir, image_name)

        if not os.path.exists(image_path):
            print(f"⚠️ 警告: 图像 {image_name} 不存在,跳过 {filename}")
            continue

        image_width, image_height = get_image_size(image_path)

        image_info = {
            "id": image_id,
            "width": image_width,
            "height": image_height,
            "file_name": image_name
        }
        data["images"].append(image_info)

        with open(os.path.join(yolo_labels_dir, filename), 'r') as file:
            lines = file.readlines()

        for line in lines:
            parts = line.strip().split()
            if len(parts) != 5:
                print(f"⚠️ 警告: 标签 {filename} 格式错误: {line.strip()}")
                continue

            category_id = int(parts[0])
            x_center = float(parts[1]) * image_width
            y_center = float(parts[2]) * image_height
            bbox_width = float(parts[3]) * image_width
            bbox_height = float(parts[4]) * image_height

            x_min = int(x_center - bbox_width / 2)
            y_min = int(y_center - bbox_height / 2)
            bbox = [x_min, y_min, bbox_width, bbox_height]
            area = bbox_width * bbox_height

            annotation_info = {
                "id": annotation_id,
                "image_id": image_id,
                "category_id": category_id,
                "bbox": bbox,
                "area": area,
                "iscrowd": 0
            }
            data["annotations"].append(annotation_info)
            annotation_id += 1

        image_id += 1

    os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
    with open(output_json_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)

    print(f"✅ 转换完成: {output_json_path}")


# 输入路径 (YOLO 格式数据集)
yolo_base_dir = "/home/yu/Yolov8/ultralytics-main/mydata0"
yolo_train_images = os.path.join(yolo_base_dir, "images/train")
yolo_train_labels = os.path.join(yolo_base_dir, "labels/train")
yolo_val_images = os.path.join(yolo_base_dir, "images/val")
yolo_val_labels = os.path.join(yolo_base_dir, "labels/val")

# 输出路径 (COCO 格式)
coco_base_dir = "/home/yu/Yolov8/ultralytics-main/mydata0_coco"
coco_train_json = os.path.join(coco_base_dir, "annotations/instances_train.json")
coco_val_json = os.path.join(coco_base_dir, "annotations/instances_val.json")

# 运行转换
yolo_to_coco(yolo_train_images, yolo_train_labels, coco_train_json)
yolo_to_coco(yolo_val_images, yolo_val_labels, coco_val_json)

3.转化结果可视化

COCO数据集JSON文件格式分为以下几个字段。

{
    "info": info, # dict
     "licenses": [license], # list ,内部是dict
     "images": [image], # list ,内部是dict
     "annotations": [annotation], # list ,内部是dict
     "categories": # list ,内部是dict
 }

可以运行以下脚本查看转化后的标签是否与图片目标对应:

  • 修改代码的json_path和img_path,json_path是标签对应的路径,img_path是图像对应的路径
'''
该代码的功能是:读取图像以及对应bbox的信息
'''
import os
from pycocotools.coco import COCO
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

json_path = "/home/yu/Yolov8/ultralytics-main/mydata0_coco/annotations/instances_val.json"
img_path = ("/home/yu/Yolov8/ultralytics-main/mydata0_coco/images/val")

# load coco data
coco = COCO(annotation_file=json_path)

# get all image index info
ids = list(sorted(coco.imgs.keys()))
print("number of images: {}".format(len(ids)))

# get all coco class labels
coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])

# 遍历前三张图像
for img_id in ids[:3]:
    # 获取对应图像id的所有annotations idx信息
    ann_ids = coco.getAnnIds(imgIds=img_id)

    # 根据annotations idx信息获取所有标注信息
    targets = coco.loadAnns(ann_ids)

    # get image file name
    path = coco.loadImgs(img_id)[0]['file_name']

    # read image
    img = Image.open(os.path.join(img_path, path)).convert('RGB')
    draw = ImageDraw.Draw(img)
    # draw box to image
    for target in targets:
        x, y, w, h = target["bbox"]
        x1, y1, x2, y2 = x, y, int(x + w), int(y + h)
        draw.rectangle((x1, y1, x2, y2))
        draw.text((x1, y1), coco_classes[target["category_id"]])

    # show image
    plt.imshow(img)
    plt.show()

运行该代码,你将会看到你的标签是否对应:
如果目标没有边界框则说明你转化的json不对!
在这里插入图片描述
在这里插入图片描述

4.数据集训练

4.1修改pth文件

将它的pth文件改一下,因为他是用的coco数据集,而我们只需要训练自己的数据集,就是下图这个文件,这是它原本的
在这里插入图片描述
新建一个.py文件,运行下面代码,就会生成一个你数据集所需要的物体数目的pth,记得改类别数!。

import torch
pretrained_weights  = torch.load('detr-r50-e632da11.pth')

num_class = 14 #这里是你的物体数+1,因为背景也算一个
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights, "detr-r50_%d.pth"%num_class

这是我们生成的。
在这里插入图片描述

4.2类别参数修改

修改models/detr.py文件,build()函数中,可以将红框部分的代码都注释掉,直接设置num_classes为自己的类别数+1
因为我的类别数是13,所以我这里num_classes=14
在这里插入图片描述

4.3训练

修改main.py文件的epochs、lr、batch_size等训练参数:
以下这些参数都在get_args_parser()函数里面。
在这里插入图片描述

修改自己的数据集路径:
在这里插入图片描述
设置输出路径:
在这里插入图片描述

修改resume为自己的预训练权重文件路径
这里就是你刚才运行脚本生成的pth文件的路径:
在这里插入图片描述
运行main.py文件
或者可以通过命令行运行:

python main.py --dataset_file "coco" --coco_path "/home/yu/Yolov8/ultralytics-main/mydata0_coco" --epoch 300 --lr=1e-4 --batch_size=8 --num_workers=4 --output_dir="outputs" --resume="detr_r50_14.pth"

5.成功运行!

在这里插入图片描述

6.参考文献

1.【DETR】训练自己的数据集-实践笔记
2. yolo数据集格式(txt)转coco格式,方便mmyolo转标签格式
3. windows10复现DEtection TRansformers(DETR)并实现自己的数据集

相关文章:

  • 计算机视觉总结
  • Golang开发棋牌游戏中的坑
  • fastapi下载图片
  • 嵌入式八股RTOS与Linux--hea4与TLSF篇
  • 《基于深度学习的指纹识别智能门禁系统》开题报告
  • Spring IOC核心详解:掌握控制反转与依赖注入
  • (四)---四元数的基础知识-(定义)-(乘法)-(逆)-(退化到二维复平面)-(四元数乘法的导数)
  • 【Spring IoC DI】深入解析 IoC & DI :Spring框架的核心设计思想和 IoC 与 DI 的思想和解耦优势
  • IDEA 快捷键ctrl+shift+f 无法全局搜索内容的问题及解决办法
  • MySQL表的增加、查询、修改、删除的基础操作
  • BEVFormer报错(预测场景与真值场景的sample_token不匹配)
  • springCloud集成tdengine(原生和mapper方式) 其一
  • Springboot之RequestAttributes学习笔记
  • 使用selenium来获取数据集
  • 在Ubuntu 22.04 中安装Docker的详细指南
  • elasticsearch 通用笔记
  • windows 安装 Elasticsearch
  • 六、GPIO中断控制器(1)—— pcf8575
  • CSRF跨站请求伪造(Cross - Site Request Forgery)
  • 蓝桥杯 劲舞团
  • 李在明正式登记参选下届韩国总统
  • 华泰柏瑞基金总经理韩勇因工作调整卸任,董事长贾波代为履职
  • 三大猪企4月生猪销量同比均增长,销售均价同比小幅下降
  • 王受文已任全国工商联党组成员
  • 前4个月我国货物贸易进出口同比增长2.4%,增速较一季度加快1.1个百分点
  • 保证断电、碰撞等事故中车门系统能够开启!汽车车门把手将迎来强制性国家标准