当前位置: 首页 > news >正文

mmrotate训练自己的数据(记录)

目录

一、代码下载

二、数据准备

三、训练

四、测试


一、代码下载地址:open-mmlab/mmrotate: OpenMMLab Rotated Object Detection Toolbox and Benchmark

预训练模型下载:

二、数据准备

(1)一般是使用rolabelimg来标注,使用rolabelimg标注生成的是xml文件,但训练的时候使用的是txt文件。xml转txt代码如下:

import os
import xml.etree.ElementTree as ET
import math
import cv2 as cv
import numpy as np


def voc_to_dota(xml_dir, xml_name, img_dir, savedImg_dir):
    txt_name = xml_name[:-4] + '.txt'  # txt文件名字:去掉xml 加上.txt
    txt_path = xml_dir + '/txt_label'  # txt文件目录:在xml目录下创建的txtl_label文件夹
    if not os.path.exists(txt_path):
        os.makedirs(txt_path)
    txt_file = os.path.join(txt_path, txt_name)  # txt完整的含名文件路径

    img_name = xml_name[:-4] + '.jpg'  # 图像名字
    img_path = os.path.join(img_dir, img_name)  # 图像完整路径
    # img = cv.imread(img_path)  # 读取图像
    img = cv.imdecode(np.fromfile(img_path, dtype=np.uint8), 1)
    xml_file = os.path.join(xml_dir, xml_name)
    tree = ET.parse(os.path.join(xml_file))  # 解析xml文件 然后转换为DOTA格式文件
    root = tree.getroot()
    with open(txt_file, "w+", encoding='UTF-8') as out_file:
        # out_file.write('imagesource:null' + '\n' + 'gsd:null' + '\n')
        for obj in root.findall('object'):
            name = obj.find('name').text
            difficult = obj.find('difficult').text
            # print(name, difficult)
            robndbox = obj.find('robndbox')
            cx = float(robndbox.find('cx').text)
            cy = float(robndbox.find('cy').text)
            w = float(robndbox.find('w').text)
            h = float(robndbox.find('h').text)
            angle = float(robndbox.find('angle').text)
            # print(cx, cy, w, h, angle)
            p0x, p0y = rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle)
            p1x, p1y = rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle)
            p2x, p2y = rotatePoint(cx, cy, cx + w / 2, cy + h / 2, -angle)
            p3x, p3y = rotatePoint(cx, cy, cx - w / 2, cy + h / 2, -angle)

            # 找最左上角的点
            dict = {p0y: p0x, p1y: p1x, p2y: p2x, p3y: p3x}
            list = find_topLeftPopint(dict)
            # print((list))
            if list[0] == p0x:
                list_xy = [p0x, p0y, p1x, p1y, p2x, p2y, p3x, p3y]
            elif list[0] == p1x:
                list_xy = [p1x, p1y, p2x, p2y, p3x, p3y, p0x, p0y]
            elif list[0] == p2x:
                list_xy = [p2x, p2y, p3x, p3y, p0x, p0y, p1x, p1y]
            else:
                list_xy = [p3x, p3y, p0x, p0y, p1x, p1y, p2x, p2y]

            # 在原图上画矩形 看是否转换正确
            cv.line(img, (int(list_xy[0]), int(list_xy[1])), (int(list_xy[2]), int(list_xy[3])), color=(255, 0, 0),
                    thickness=3)
            cv.line(img, (int(list_xy[2]), int(list_xy[3])), (int(list_xy[4]), int(list_xy[5])), color=(0, 255, 0),
                    thickness=3)
            cv.line(img, (int(list_xy[4]), int(list_xy[5])), (int(list_xy[6]), int(list_xy[7])), color=(0, 0, 255),
                    thickness=2)
            cv.line(img, (int(list_xy[6]), int(list_xy[7])), (int(list_xy[0]), int(list_xy[1])), color=(255, 255, 0),
                    thickness=2)

            data = str(list_xy[0]) + " " + str(list_xy[1]) + " " + str(list_xy[2]) + " " + str(list_xy[3]) + " " + \
                   str(list_xy[4]) + " " + str(list_xy[5]) + " " + str(list_xy[6]) + " " + str(list_xy[7]) + " "
            data = data + name + " " + difficult + "\n"
            out_file.write(data)
        if not os.path.exists(savedImg_dir):
            os.makedirs(savedImg_dir)
        out_img = os.path.join(savedImg_dir, xml_name[:-4] + '.jpg')
        # cv.imwrite(out_img, img)
        cv.imencode(".png", img)[1].tofile(out_img)

def find_topLeftPopint(dict):
    dict_keys = sorted(dict.keys())  # y值
    temp = [dict[dict_keys[0]], dict[dict_keys[1]]]
    minx = min(temp)
    if minx == temp[0]:
        miny = dict_keys[0]
    else:
        miny = dict_keys[1]
    return [minx, miny]


# 转换成四点坐标
def rotatePoint(xc, yc, xp, yp, theta):
    xoff = xp - xc
    yoff = yp - yc
    cosTheta = math.cos(theta)
    sinTheta = math.sin(theta)
    pResx = cosTheta * xoff + sinTheta * yoff
    pResy = - sinTheta * xoff + cosTheta * yoff
    # pRes = (xc + pResx, yc + pResy)
    # 保留一位小数点
    return float(format(xc + pResx, '.1f')), float(format(yc + pResy, '.1f'))
    # return xc + pResx, yc + pResy


import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='数据格式转换')
    parser.add_argument('--xml-dir', default=r'C:\Users\Admin\Desktop\tmp_test\xml', help='original xml file dictionary')
    parser.add_argument('--img-dir', default=r'C:\Users\Admin\Desktop\tmp_test\img', help='original image dictionary')
    parser.add_argument('--outputImg-dir', default=r'C:\Users\Admin\Desktop\tmp_test\train_txt',
                        help='saved image dictionary after dealing ')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    xml_path = args.xml_dir
    xmlFile_list = os.listdir(xml_path)
    print(xmlFile_list)
    for i in range(0, len(xmlFile_list)):
        if ('.xml' in xmlFile_list[i]) or ('.XML' in xmlFile_list[i]):
            voc_to_dota(xml_path, xmlFile_list[i], args.img_dir, args.outputImg_dir)
            print('----------------------------------------{}{}----------------------------------------'
                  .format(xmlFile_list[i], ' has Done!'))
        else:
            print(xmlFile_list[i] + ' is not xml file')

(2)训练数据格式分布如下:images文件夹里是图片,labels文件夹里是对应的txt文件。txt前8个数字就是绘制的旋转框的四个点坐标, 然后是类别名称,最后一个是检测困难程度,0表示不困难

       

三、训练

(1)参数配置:主要配置–config和–work-dir两个。其中–config为训练配置文件,–work-dir为模型和日志保存文件夹。config表示旋转使用哪种模型算法进行训练,这里使用的是rotated_faster_rcnn_r50_fpn_1x_dota_le90.py,可以根据自己的要求进行选择使用

 (2)类别数设置: 根据自己训练的类别数修改

(3)类别名称修改,如果只有一个类别后面需要加上,确保这是一个tuple

 (4)训练数据集路径设置:

 (5)训练epoch设置

(6)训练图片格式修改,默认的代码只支持png格式的图片,在此处进行修改

 (7)预训练模型设置:

(7)配置完之后就可以运行tool/train.py文件了 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/124046.html

相关文章:

  • 使用多进程和 Socket 接收解析数据并推送到 Kafka 的高性能架构
  • 使用js创建img加载阿里云oss图片跨域的问题
  • opencv常用边缘检测算子示例
  • Java 并发-newFixedThreadPool
  • Java——接口扩展
  • 记录一下移动端uView动态表单校验
  • 安装npm install element-plus --save报错
  • OpenCV 图形API(24)图像滤波-----双边滤波函数bilateralFilter()
  • 随机森林与决策树
  • 什么是虚拟线程?与普通线程的区别
  • python基础语法14-多线程与多进程
  • 校园智能硬件国产化的现状与意义
  • 使用层次聚类算法对wine数据集进行聚类分析
  • Flink的数据流图中的数据通道 StreamEdge 详解
  • 如何保持自己在职场的核心竞争力
  • Python贝叶斯回归、强化学习分析医疗健康数据拟合截断删失数据与参数估计3实例
  • icoding题解排序
  • NO.87十六届蓝桥杯备战|动态规划-完全背包|疯狂的采药|Buying Hay|纪念品(C++)
  • x265 编码器中运动搜索 ME 方法对比实验
  • C++基础精讲-03
  • 苍穹外卖总结
  • 【Web API系列】WebSocketStream API 深度实践:构建高吞吐量实时应用的流式通信方案
  • 23种设计模式生活化场景,帮助理解
  • 洛谷刷题Day1——P1706+P1157+P2089+P3654
  • 要查看 FAISS 使用的 OpenMP 版本,需根据安装方式和系统环境采用不同方法。以下是具体步骤和原理分析:
  • [设计模式]发布订阅者模式解耦业务和UI(以Axios拦截器处理响应状态为例)
  • Spring Boot 自动加载流程详解
  • 8.3.5 ToolStripContainer(工具栏容器)控件
  • 线代第四课:行列式的性质
  • 电子元件浸入式冷却