当前位置: 首页 > news >正文

大型网站 开发流程上海招聘网最新招聘信息网

大型网站 开发流程,上海招聘网最新招聘信息网,安卓手机优化神器,网站建设推广邮件目标检测数据集转换为图像分类数据集一、简介二、数据集转换代码2.1实现代码2.2代码说明三、图像分类训练代码3.1训练代码3.2代码说明一、简介 图像识别任务按照输出结果的不同通常可以分为三大类:图像分类、目标检测和图像分割,有时出于不同应用的要求…

目标检测数据集转换为图像分类数据集

  • 一、简介
  • 二、数据集转换代码
    • 2.1实现代码
    • 2.2代码说明
  • 三、图像分类训练代码
    • 3.1训练代码
    • 3.2代码说明

一、简介

图像识别任务按照输出结果的不同通常可以分为三大类:图像分类、目标检测和图像分割,有时出于不同应用的要求,需要将目标检测任务转变为图像分类任务,那么由于两种任务数据集格式的不同,就需要进行数据集的转换

在这里插入图片描述

本文以YOLO目标检测数据集为例,给出了数据集转换的实现代码以供参考。

二、数据集转换代码

2.1实现代码

代码如下:

def process_data(args):imgpath, foldername = argsimgpath = imgpath[:-1]imgname = imgpath.split('/')[-1]destimgpath = 'classify_dataset/'+foldername+'/'image = Image.open(imgpath).convert('RGB')image_data = np.array(image)labpath = imgpath.replace('.jpg', '.txt').replace('images', 'labels')with open(labpath, 'r') as f:labdataList = f.readlines()if len(labdataList) == 0:destimgdata = image_data[64:192,64:192]destimg = Image.fromarray(destimgdata)destimg.save(destimgpath+'noobj/'+imgname)else:for idx, labeldata in enumerate(labdataList):parts = labeldata.strip().split()x_center, y_center, w, h = (float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4]))gtx1 = int((x_center - w / 2) * 256)gty1 = int((y_center - h / 2) * 256)gtx2 = int((x_center + w / 2) * 256)gty2 = int((y_center + h / 2) * 256)gtx1 = max(min(gtx1,256),0)gtx2 = max(min(gtx2,256),0)gty1 = max(min(gty1,256),0)gty2 = max(min(gty2,256),0)destimgdata = image_data[gty1:gty2,gtx1:gtx2]try:destimg = Image.fromarray(destimgdata)destimg = destimg.resize((128, 128))except:continuedestimg.save(destimgpath+'obj/'+imgname[:-4]+'_'+str(idx+1)+'.jpg')def gen_classify_dataset():txtpath01 = 'yolodataset/train.txt'txtpath02 = 'yolodataset/val.txt'with open(txtpath01, 'r') as f:imgpathList01 = f.readlines()with open(txtpath02, 'r') as f:imgpathList02 = f.readlines()args_list = [(imgpath, 'train') for imgpath in imgpathList01]with Pool(10) as p:for _ in tqdm(p.imap_unordered(process_data, args_list), total=len(args_list), desc="Processing"):passargs_list = [(imgpath, 'val') for imgpath in imgpathList02]with Pool(10) as p:for _ in tqdm(p.imap_unordered(process_data, args_list), total=len(args_list), desc="Processing"):pass

2.2代码说明

上述代码主函数为gen_classify_dataset,在函数内部通过多核并行的方式调用process_data函数,process_data函数执行具体的功能逻辑,这种方式可以显著提高图像处理速度,具体速度提高细节可见本博客另一篇文章30秒处理1万张图片——图像数据增强的高效执行代码。

process_data函数逻辑如下:

  1. 读取yolo标签文件中的标签内容,如果为空,说明图片中不包含目标,则从图片中间截取128×128大小的区域,保存至noobj文件夹下,noobj作为这一类图片的类别标签
  2. 如果不为空,则循环读取标签中的box信息,根据box截取目标区域,然后调整大小至128×128,保存在obj文件夹下,obj作为这一类图片的类别标签
  3. 上述代码假设只有一类目标,可根据具体场景进行修改。

gen_classify_dataset通过传入不同参数,依次对train、val数据集进行转换。

三、图像分类训练代码

3.1训练代码

下面是一个使用上一节转换完成的数据集的图像分类训练代码,仅供参考:

def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")batch_size = 8epochs = 10data_transform = {"train": transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}train_dataset = datasets.ImageFolder(root="classify_dataset/train",transform=data_transform["train"])train_num = len(train_dataset)nw = 4  # number of workerstrain_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root="classify_dataset/val",transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))net = model(num_classes=2)net.to(device)loss_function = nn.CrossEntropyLoss()params = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)best_acc = 0.0save_path = './classify.pth'train_steps = len(train_loader)for epoch in range(epochs):net.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)net.eval()acc = 0.0with torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')

3.2代码说明

上述代码是一个简单的图像分类训练代码,代码逻辑较为清晰,主要定义了batchsize、训练epoch、损失函数、优化器等超参数,但并没有定义模型结构,所以需要进行修改后才可使用。

代码在每epoch训练结束会对会模型进行评估,计算分类准确率这一指标,并且将准确率最高的模型权重保存下来。

http://www.dtcms.com/a/518161.html

相关文章:

  • 您在工信部门备案网站获取的icp备案号老实人做网站
  • 中国建设银行公积金网站首页株洲广告公司找v信hyhyk1做推广好
  • 2016市网站建设总结家装网络平台哪家好
  • 制作网站管理系统弄网站赚钱吗
  • 网站建设管理工作计划网站建设的七个流程步骤
  • wordpress网站被拒登俞润装饰做哪几个网站
  • 站内信息 wordpress深圳商城网站建设
  • 免费学设计的网站微信网站域名备案成功后怎么做
  • 网站建设佰金手指科杰二五广州万户网络科技有限公司
  • 网站建设有关的职位企业信息系统的功能和特点
  • 网站备案很麻烦吗应用下载app排行榜
  • 黑龙江省建设教育协会网站手机wap网站制作
  • 网站开发的可行性网站标题修改
  • 网站链接优化宁波网站建设速成
  • 建设工程图审管理信息系统网站深圳创业板
  • 网站十大品牌施工企业自建自用的工程可以不进行招标是否正确
  • 网站空间商查询做三轨网站犯法吗
  • 在什么网站做调查问卷sem是什么意思中文
  • 惠州网页建站模板服务器运维
  • 贵阳做网站的公司有哪些房屋设计装修网站
  • 一个ip 做2个网站吗广州公司注册流程及需要的材料
  • 网站的组成关于音乐的个人网站
  • 深圳网站建设 设计贝尔网页分析从哪些方面
  • 3d建模视频教学广东网站seo营销
  • 抚顺网站建设2345网址导航怎么卸载win10
  • 该网站正在紧急升级维护中iis如何用ip地址做域名访问网站
  • 如何设计制作企业网站小城镇建设的网站
  • 衡水市网站建设公司wordpress seo模块
  • 福建省法冶建设知识有奖网站上海服装外贸公司排名
  • dw做网站的导航栏怎么做淘宝客自建网站做还是用微信qq做