PyTorch Application Notes, Part 1: Image Localization Based on ResNet
PyTorch Application Notes, Part 1
This note provides code for image localization based on a ResNet network. The code uses ResNet to localize pet heads, using the Oxford-IIIT Pet Dataset. Besides the code itself, this note also explains the parts of the code that are not easy to understand.
The code was written in PyCharm 2024.1.3 with Python 3.11, numpy 1.26.4, PyTorch 2.0.0, and d2l 1.0.3.
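Before running the code, it may help to confirm that the installed library versions and the dataset layout match those listed above. The following is only an illustrative sketch, assuming the Oxford-IIIT Pet Dataset has been unpacked into a local dataset/ folder as the glob paths in the code expect:

import glob

import numpy as np
import torch
import torchvision

# Installed versions; this note was written against torch 2.0.0 and numpy 1.26.4
print(torch.__version__, torchvision.__version__, np.__version__)

# The code below expects the Oxford-IIIT Pet Dataset unpacked as:
#   dataset/images/*.jpg            -- the photos
#   dataset/annotations/xmls/*.xml  -- bounding-box annotations (only some images have one)
print(len(glob.glob('dataset/images/*.jpg')))
print(len(glob.glob('dataset/annotations/xmls/*.xml')))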
Table of Contents
PyTorch Application Notes, Part 1
I. Code
II. Notes
1. Code approach
2. glob.glob('dataset/images/*.jpg')
3. xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]
4. xml = open(r'{}'.format(track)).read()
5. data_tree = etree.HTML(xml)
6. img_width = int(data_tree.xpath('//size/width/text()')[0])
7. label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))
8. num = np.random.permutation(len(images))
data_images = np.array(images)[num]
9. imgs_data = np.repeat(imgs_data[:, :, np.newaxis], 3, axis=2)
10. resnet = torchvision.models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)
net_feature = resnet.fc.in_features
11. self.conv_base = nn.Sequential(*list(resnet.children())[:-1])
I. Code
The full code is shown below:
import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
from lxml import etree
from matplotlib.patches import Rectangle
import glob
from PIL import Image
from torch.optim import lr_scheduler
from torchvision.models import ResNet101_Weights
import time

# Create the input image list
data_images = glob.glob('dataset/images/*.jpg')
data_xmls = glob.glob('dataset/annotations/xmls/*.xml')
xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]
# Keep only the images that have a matching XML annotation
images = [image for image in data_images
          if image.split('\\')[-1].split('.jpg')[0] in xmls_names]

def transform_labels(track):
    # Parse one XML annotation and return the bounding box scaled to [0, 1]
    xml = open(r'{}'.format(track)).read()
    data_tree = etree.HTML(xml)
    img_width = int(data_tree.xpath('//size/width/text()')[0])
    img_height = int(data_tree.xpath('//size/height/text()')[0])
    x_min = int(data_tree.xpath('//bndbox/xmin/text()')[0])
    y_min = int(data_tree.xpath('//bndbox/ymin/text()')[0])
    x_max = int(data_tree.xpath('//bndbox/xmax/text()')[0])
    y_max = int(data_tree.xpath('//bndbox/ymax/text()')[0])
    return [x_min / img_width, y_min / img_height,
            x_max / img_width, y_max / img_height]

labels = [transform_labels(track) for track in data_xmls]
label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))
num = np.random.permutation(len(images))
data_images = np.array(images)[num]
# Reshape the label arrays into a two-dimensional structure (one column) and apply the same shuffle
label_x_min = np.array(label_x_min).astype(np.float32).reshape(-1, 1)[num]
label_y_min = np.array(label_y_min).astype(np.float32).reshape(-1, 1)[num]
label_x_max = np.array(label_x_max).astype(np.float32).reshape(-1, 1)[num]
label_y_max = np.array(label_y_max).astype(np.float32).reshape(-1, 1)[num]
segment = int(len(images) * 0.7)
train_images = data_images[:segment]
x_min_train_label = label_x_min[:segment]
y_min_train_label = label_y_min[:segment]
x_max_train_label = label_x_max[:segment]
y_max_train_label = label_y_max[:segment]

test_images = data_images[segment:]
x_min_test_label = label_x_min[segment:]
y_min_test_label = label_y_min[segment:]
x_max_test_label = label_x_max[segment:]
y_max_test_label = label_y_max[segment:]

img_scale = 224
transform = transforms.Compose([
    transforms.Resize((img_scale, img_scale)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Create the Dataset object
class dataset_Oxford(data.Dataset):
    def __init__(self, images_in, out1_label, out2_label,
                 out3_label, out4_label, transition):
        self.images = images_in
        self.out1_label = out1_label
        self.out2_label = out2_label
        self.out3_label = out3_label
        self.out4_label = out4_label
        self.transition = transition

    def __getitem__(self, index):
        img = self.images[index]
        out1_label = self.out1_label[index]
        out2_label = self.out2_label[index]
        out3_label = self.out3_label[index]
        out4_label = self.out4_label[index]