当前位置: 首页 > news >正文

Pytorch应用 小记 第一回:基于ResNet网络的图像定位

Pytorch应用小记 第一回

本次小记,提供了一份基于ResNet网络的图像定位代码。在本回代码中,实现了ResNet网络定位宠物头像(采用的数据集是Oxford-IIIT Pet Dataset)。除了提供代码外,本小记对代码中不容易理解的内容,也进行了讲解。
本代码的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0,d2l的版本是1.0.3


文章目录

Pytorch应用小记 第一回

一、代码

二、小记

1.代码思路

2. glob.glob('dataset/images/*.jpg')

3 .xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]

4.xml = open(r'{}'.format(track)).read()

5. data_tree = etree.HTML(xml)

6.img_width = int(data_tree.xpath('//size/width/text()')[0])

7.label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))

8.num = np.random.permutation(len(images))

   data_images = np.array(images)[num]

9. imgs_data = np.repeat(imgs_data[:, :, np.newaxis], 3, axis=2)

10.resnet = torchvision.models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

net_feature = resnet.fc.in_features

11.self.conv_base = nn.Sequential(*list(resnet.children())[:-1])


一、代码

代码如下所示:

import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
from lxml import etree
from matplotlib.patches import Rectangle
import glob
from PIL import Image
from torch.optim import lr_scheduler
from torchvision.models import ResNet101_Weights
import time# 创建输入图像
data_images = glob.glob('dataset/images/*.jpg')
data_xmls = glob.glob('dataset/annotations/xmls/*.xml')
xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]
images = [image for image in data_images ifimage.split('\\')[-1].split('.jpg')[0] in xmls_names]def transform_labels(track):xml = open(r'{}'.format(track)).read()data_tree = etree.HTML(xml)img_width = int(data_tree.xpath('//size/width/text()')[0])img_height = int(data_tree.xpath('//size/height/text()')[0])x_min = int(data_tree.xpath('//bndbox/xmin/text()')[0])y_min = int(data_tree.xpath('//bndbox/ymin/text()')[0])x_max = int(data_tree.xpath('//bndbox/xmax/text()')[0])y_max = int(data_tree.xpath('//bndbox/ymax/text()')[0])return [x_min / img_width, y_min / img_height, x_max / img_width, y_max / img_height]labels = [transform_labels(track) for track in data_xmls]
label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))
num = np.random.permutation(len(images))
data_images = np.array(images)[num]
# 数组或张量的形状调整为二维结构
label_x_min = np.array(label_x_min).astype(np.float32).reshape(-1, 1)[num]
label_y_min = np.array(label_y_min).astype(np.float32).reshape(-1, 1)[num]
label_x_max = np.array(label_x_max).astype(np.float32).reshape(-1, 1)[num]
label_y_max = np.array(label_y_max).astype(np.float32).reshape(-1, 1)[num]
segment = int(len(images) * 0.7)
train_images = data_images[:segment]
x_min_train_label = label_x_min[:segment]
y_min_train_label = label_y_min[:segment]
x_max_train_label = label_x_max[:segment]
y_max_train_label = label_y_max[:segment]test_images = data_images[segment:]
x_min_test_label = label_x_min[segment:]
y_min_test_label = label_y_min[segment:]
x_max_test_label = label_x_max[segment:]
y_max_test_label = label_y_max[segment:]img_scale = 224
transform = transforms.Compose([transforms.Resize((img_scale, img_scale)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])# 创建Dataset对象
class dataset_Oxford(data.Dataset):def __init__(self, images_in, out1_label, out2_label,out3_label, out4_label, transition):self.images = images_inself.out1_label = out1_labelself.out2_label = out2_labelself.out3_label = out3_labelself.out4_label = out4_labelself.transition = transitiondef __getitem__(self, index):img = self.images[index]out1_label = self.out1_label[index]out2_label = self.out2_label[index]out3_label = self.out3_label[index]out4_label = self.o

相关文章:

  • 汇编语言的温度魔法:单总线温度采集与显示的奇幻之旅
  • Python-函数
  • 备战菊厂笔试3
  • C# 使用 WinUI 3 项目模板创建桌面应用程序
  • C++GO语言微服务之图片、短信验证码生成及存储
  • Ajax基础
  • .Net HttpClient 管理客户端(初始化与生命周期管理)
  • 202534 | KafKa简介+应用场景+集群搭建+快速入门
  • kafka的安装及简单使用
  • [sklearn机器学习概述]机器学习-part3
  • 运算符与表达式 -《Go语言实战指南》
  • Scala与Go的异同教程
  • 【计算机视觉】OpenCV项目实战:基于OpenCV的图像分割技术深度解析与实践指南
  • 5.1 神经网络: 层和块
  • 电子电器架构 --- 车载以太网拓扑
  • k8s删除pv和pvc后,vg存储没释放分析
  • word换行符和段落标记
  • 2024年AI发展趋势全面解析:从多模态到AGI的突破
  • Python 从 SQLite 数据库中批量提取图像数据
  • 深拷贝与浅拷贝:理解 Python 中的对象复制机制
  • 同济大学原常务副校长、著名隧道及地下工程专家李永盛逝世
  • 优秀“博主”在上海杨浦购房最高补贴200万元,有何条件?
  • 教育部答澎湃:2025世界数字教育大会将发布系列重磅成果
  • 赵作海因病离世,妻子李素兰希望过平静生活
  • 心相印回应官方旗舰店客服辱骂消费者:正排查
  • 视频丨习近平同普京会谈:共同弘扬正确二战史观,维护联合国权威和地位