当前位置: 首页 > news >正文

目标检测数据集转换为图像分类数据集

目标检测数据集转换为图像分类数据集

  • 一、简介
  • 二、数据集转换代码
    • 2.1实现代码
    • 2.2代码说明
  • 三、图像分类训练代码
    • 3.1训练代码
    • 3.2代码说明

一、简介

图像识别任务按照输出结果的不同通常可以分为三大类:图像分类、目标检测和图像分割,有时出于不同应用的要求,需要将目标检测任务转变为图像分类任务,那么由于两种任务数据集格式的不同,就需要进行数据集的转换

在这里插入图片描述

本文以YOLO目标检测数据集为例,给出了数据集转换的实现代码以供参考。

二、数据集转换代码

2.1实现代码

代码如下:

def process_data(args):imgpath, foldername = argsimgpath = imgpath[:-1]imgname = imgpath.split('/')[-1]destimgpath = 'classify_dataset/'+foldername+'/'image = Image.open(imgpath).convert('RGB')image_data = np.array(image)labpath = imgpath.replace('.jpg', '.txt').replace('images', 'labels')with open(labpath, 'r') as f:labdataList = f.readlines()if len(labdataList) == 0:destimgdata = image_data[64:192,64:192]destimg = Image.fromarray(destimgdata)destimg.save(destimgpath+'noobj/'+imgname)else:for idx, labeldata in enumerate(labdataList):parts = labeldata.strip().split()x_center, y_center, w, h = (float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4]))gtx1 = int((x_center - w / 2) * 256)gty1 = int((y_center - h / 2) * 256)gtx2 = int((x_center + w / 2) * 256)gty2 = int((y_center + h / 2) * 256)gtx1 = max(min(gtx1,256),0)gtx2 = max(min(gtx2,256),0)gty1 = max(min(gty1,256),0)gty2 = max(min(gty2,256),0)destimgdata = image_data[gty1:gty2,gtx1:gtx2]try:destimg = Image.fromarray(destimgdata)destimg = destimg.resize((128, 128))except:continuedestimg.save(destimgpath+'obj/'+imgname[:-4]+'_'+str(idx+1)+'.jpg')def gen_classify_dataset():txtpath01 = 'yolodataset/train.txt'txtpath02 = 'yolodataset/val.txt'with open(txtpath01, 'r') as f:imgpathList01 = f.readlines()with open(txtpath02, 'r') as f:imgpathList02 = f.readlines()args_list = [(imgpath, 'train') for imgpath in imgpathList01]with Pool(10) as p:for _ in tqdm(p.imap_unordered(process_data, args_list), total=len(args_list), desc="Processing"):passargs_list = [(imgpath, 'val') for imgpath in imgpathList02]with Pool(10) as p:for _ in tqdm(p.imap_unordered(process_data, args_list), total=len(args_list), desc="Processing"):pass

2.2代码说明

上述代码主函数为gen_classify_dataset,在函数内部通过多核并行的方式调用process_data函数,process_data函数执行具体的功能逻辑,这种方式可以显著提高图像处理速度,具体速度提高细节可见本博客另一篇文章30秒处理1万张图片——图像数据增强的高效执行代码。

process_data函数逻辑如下:

  1. 读取yolo标签文件中的标签内容,如果为空,说明图片中不包含目标,则从图片中间截取128×128大小的区域,保存至noobj文件夹下,noobj作为这一类图片的类别标签
  2. 如果不为空,则循环读取标签中的box信息,根据box截取目标区域,然后调整大小至128×128,保存在obj文件夹下,obj作为这一类图片的类别标签
  3. 上述代码假设只有一类目标,可根据具体场景进行修改。

gen_classify_dataset通过传入不同参数,依次对train、val数据集进行转换。

三、图像分类训练代码

3.1训练代码

下面是一个使用上一节转换完成的数据集的图像分类训练代码,仅供参考:

def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")batch_size = 8epochs = 10data_transform = {"train": transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}train_dataset = datasets.ImageFolder(root="classify_dataset/train",transform=data_transform["train"])train_num = len(train_dataset)nw = 4  # number of workerstrain_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root="classify_dataset/val",transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))net = model(num_classes=2)net.to(device)loss_function = nn.CrossEntropyLoss()params = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)best_acc = 0.0save_path = './classify.pth'train_steps = len(train_loader)for epoch in range(epochs):net.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)net.eval()acc = 0.0with torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')

3.2代码说明

上述代码是一个简单的图像分类训练代码,代码逻辑较为清晰,主要定义了batchsize、训练epoch、损失函数、优化器等超参数,但并没有定义模型结构,所以需要进行修改后才可使用。

代码在每epoch训练结束会对会模型进行评估,计算分类准确率这一指标,并且将准确率最高的模型权重保存下来。

http://www.dtcms.com/a/344872.html

相关文章:

  • Pandas中的SettingWithCopyWarning警告出现原因及解决方法
  • 共享内存详细解释
  • 前端在WebSocket中加入Token的方法
  • 12-Linux系统用户管理及基础权限
  • 塞尔达传说 王国之泪 PC/手机双端 免安装中文版
  • celery
  • C语言翻译环境作业
  • 大学校园安消一体化平台——多警合一实现智能联动与网格化管理
  • 【链表 - LeetCode】19. 删除链表的倒数第 N 个结点
  • Android.mk 基础
  • Electron 核心 API 全解析:从基础到实战场景
  • 从零开始搭 Linux 环境:VMware 下 CentOS 7 的安装与配置全流程(附图解)
  • openstack的novnc兼容问题
  • 【日常学习】2025-8-20 框架中控件子类实例化设计
  • FPGA学习笔记——简单的IIC读写EEPROM
  • LeetCode 3195.包含所有 1 的最小矩形面积 I:简单题-求长方形四个范围
  • 化工生产场景下设备状态监测与智能润滑预测性维护路径
  • 校园作品互评管理移动端的设计与实现
  • Boost库中boost::random::normal_distribution(正态分布)详解和实战示例
  • 腾讯云EdgeOne安全防护:快速上手,全面抵御Web攻击
  • 如何优雅的监听dom的变化(尺寸)
  • php apache无法接收到Authorization header
  • JDK17 升级避坑指南:技术原理与解决方案详解
  • 【学习记录】structuredClone,URLSearchParams,groupBy
  • 【大语言模型 14】Transformer权重初始化策略:从Xavier到GPT的参数初始化演进之路
  • 网络编程8.22
  • Python面试常考函数
  • 技术分析 剖析一个利用FTP快捷方式与批处理混淆的钓鱼攻击
  • RSS与今日头条技术对比分析
  • Unreal Engine UObject