目标检测数据集转换为图像分类数据集
目标检测数据集转换为图像分类数据集
- 一、简介
- 二、数据集转换代码
- 2.1实现代码
- 2.2代码说明
- 三、图像分类训练代码
- 3.1训练代码
- 3.2代码说明
一、简介
图像识别任务按照输出结果的不同通常可以分为三大类:图像分类、目标检测和图像分割,有时出于不同应用的要求,需要将目标检测任务转变为图像分类任务,那么由于两种任务数据集格式的不同,就需要进行数据集的转换。
本文以YOLO目标检测数据集为例,给出了数据集转换的实现代码以供参考。
二、数据集转换代码
2.1实现代码
代码如下:
def process_data(args):imgpath, foldername = argsimgpath = imgpath[:-1]imgname = imgpath.split('/')[-1]destimgpath = 'classify_dataset/'+foldername+'/'image = Image.open(imgpath).convert('RGB')image_data = np.array(image)labpath = imgpath.replace('.jpg', '.txt').replace('images', 'labels')with open(labpath, 'r') as f:labdataList = f.readlines()if len(labdataList) == 0:destimgdata = image_data[64:192,64:192]destimg = Image.fromarray(destimgdata)destimg.save(destimgpath+'noobj/'+imgname)else:for idx, labeldata in enumerate(labdataList):parts = labeldata.strip().split()x_center, y_center, w, h = (float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4]))gtx1 = int((x_center - w / 2) * 256)gty1 = int((y_center - h / 2) * 256)gtx2 = int((x_center + w / 2) * 256)gty2 = int((y_center + h / 2) * 256)gtx1 = max(min(gtx1,256),0)gtx2 = max(min(gtx2,256),0)gty1 = max(min(gty1,256),0)gty2 = max(min(gty2,256),0)destimgdata = image_data[gty1:gty2,gtx1:gtx2]try:destimg = Image.fromarray(destimgdata)destimg = destimg.resize((128, 128))except:continuedestimg.save(destimgpath+'obj/'+imgname[:-4]+'_'+str(idx+1)+'.jpg')def gen_classify_dataset():txtpath01 = 'yolodataset/train.txt'txtpath02 = 'yolodataset/val.txt'with open(txtpath01, 'r') as f:imgpathList01 = f.readlines()with open(txtpath02, 'r') as f:imgpathList02 = f.readlines()args_list = [(imgpath, 'train') for imgpath in imgpathList01]with Pool(10) as p:for _ in tqdm(p.imap_unordered(process_data, args_list), total=len(args_list), desc="Processing"):passargs_list = [(imgpath, 'val') for imgpath in imgpathList02]with Pool(10) as p:for _ in tqdm(p.imap_unordered(process_data, args_list), total=len(args_list), desc="Processing"):pass
2.2代码说明
上述代码主函数为gen_classify_dataset,在函数内部通过多核并行的方式调用process_data函数,process_data函数执行具体的功能逻辑,这种方式可以显著提高图像处理速度,具体速度提高细节可见本博客另一篇文章30秒处理1万张图片——图像数据增强的高效执行代码。
process_data函数逻辑如下:
- 读取yolo标签文件中的标签内容,如果为空,说明图片中不包含目标,则从图片中间截取128×128大小的区域,保存至noobj文件夹下,noobj作为这一类图片的类别标签;
- 如果不为空,则循环读取标签中的box信息,根据box截取目标区域,然后调整大小至128×128,保存在obj文件夹下,obj作为这一类图片的类别标签;
- 上述代码假设只有一类目标,可根据具体场景进行修改。
gen_classify_dataset通过传入不同参数,依次对train、val数据集进行转换。
三、图像分类训练代码
3.1训练代码
下面是一个使用上一节转换完成的数据集的图像分类训练代码,仅供参考:
def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")batch_size = 8epochs = 10data_transform = {"train": transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}train_dataset = datasets.ImageFolder(root="classify_dataset/train",transform=data_transform["train"])train_num = len(train_dataset)nw = 4 # number of workerstrain_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root="classify_dataset/val",transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))net = model(num_classes=2)net.to(device)loss_function = nn.CrossEntropyLoss()params = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)best_acc = 0.0save_path = './classify.pth'train_steps = len(train_loader)for epoch in range(epochs):net.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)net.eval()acc = 0.0with torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')
3.2代码说明
上述代码是一个简单的图像分类训练代码,代码逻辑较为清晰,主要定义了batchsize、训练epoch、损失函数、优化器等超参数,但并没有定义模型结构,所以需要进行修改后才可使用。
代码在每epoch训练结束会对会模型进行评估,计算分类准确率这一指标,并且将准确率最高的模型权重保存下来。