当前位置: 首页 > news >正文

Pytorch应用 小记 第一回:基于ResNet网络的图像定位

Pytorch应用小记 第一回

本次小记,提供了一份基于ResNet网络的图像定位代码。在本回代码中,实现了ResNet网络定位宠物头像(采用的数据集是Oxford-IIIT Pet Dataset)。除了提供代码外,本小记对代码中不容易理解的内容,也进行了讲解。
本代码的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0,d2l的版本是1.0.3


文章目录

Pytorch应用小记 第一回

一、代码

二、小记

1.代码思路

2. glob.glob('dataset/images/*.jpg')

3 .xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]

4.xml = open(r'{}'.format(track)).read()

5. data_tree = etree.HTML(xml)

6.img_width = int(data_tree.xpath('//size/width/text()')[0])

7.label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))

8.num = np.random.permutation(len(images))

   data_images = np.array(images)[num]

9. imgs_data = np.repeat(imgs_data[:, :, np.newaxis], 3, axis=2)

10.resnet = torchvision.models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

net_feature = resnet.fc.in_features

11.self.conv_base = nn.Sequential(*list(resnet.children())[:-1])


一、代码

代码如下所示:

import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
from lxml import etree
from matplotlib.patches import Rectangle
import glob
from PIL import Image
from torch.optim import lr_scheduler
from torchvision.models import ResNet101_Weights
import time# 创建输入图像
data_images = glob.glob('dataset/images/*.jpg')
data_xmls = glob.glob('dataset/annotations/xmls/*.xml')
xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in data_xmls]
images = [image for image in data_images ifimage.split('\\')[-1].split('.jpg')[0] in xmls_names]def transform_labels(track):xml = open(r'{}'.format(track)).read()data_tree = etree.HTML(xml)img_width = int(data_tree.xpath('//size/width/text()')[0])img_height = int(data_tree.xpath('//size/height/text()')[0])x_min = int(data_tree.xpath('//bndbox/xmin/text()')[0])y_min = int(data_tree.xpath('//bndbox/ymin/text()')[0])x_max = int(data_tree.xpath('//bndbox/xmax/text()')[0])y_max = int(data_tree.xpath('//bndbox/ymax/text()')[0])return [x_min / img_width, y_min / img_height, x_max / img_width, y_max / img_height]labels = [transform_labels(track) for track in data_xmls]
label_x_min, label_y_min, label_x_max, label_y_max = list(zip(*labels))
num = np.random.permutation(len(images))
data_images = np.array(images)[num]
# 数组或张量的形状调整为二维结构
label_x_min = np.array(label_x_min).astype(np.float32).reshape(-1, 1)[num]
label_y_min = np.array(label_y_min).astype(np.float32).reshape(-1, 1)[num]
label_x_max = np.array(label_x_max).astype(np.float32).reshape(-1, 1)[num]
label_y_max = np.array(label_y_max).astype(np.float32).reshape(-1, 1)[num]
segment = int(len(images) * 0.7)
train_images = data_images[:segment]
x_min_train_label = label_x_min[:segment]
y_min_train_label = label_y_min[:segment]
x_max_train_label = label_x_max[:segment]
y_max_train_label = label_y_max[:segment]test_images = data_images[segment:]
x_min_test_label = label_x_min[segment:]
y_min_test_label = label_y_min[segment:]
x_max_test_label = label_x_max[segment:]
y_max_test_label = label_y_max[segment:]img_scale = 224
transform = transforms.Compose([transforms.Resize((img_scale, img_scale)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])# 创建Dataset对象
class dataset_Oxford(data.Dataset):def __init__(self, images_in, out1_label, out2_label,out3_label, out4_label, transition):self.images = images_inself.out1_label = out1_labelself.out2_label = out2_labelself.out3_label = out3_labelself.out4_label = out4_labelself.transition = transitiondef __getitem__(self, index):img = self.images[index]out1_label = self.out1_label[index]out2_label = self.out2_label[index]out3_label = self.out3_label[index]out4_label = self.o
http://www.dtcms.com/a/182625.html

相关文章:

  • 汇编语言的温度魔法:单总线温度采集与显示的奇幻之旅
  • Python-函数
  • 备战菊厂笔试3
  • C# 使用 WinUI 3 项目模板创建桌面应用程序
  • C++GO语言微服务之图片、短信验证码生成及存储
  • Ajax基础
  • .Net HttpClient 管理客户端(初始化与生命周期管理)
  • 202534 | KafKa简介+应用场景+集群搭建+快速入门
  • kafka的安装及简单使用
  • [sklearn机器学习概述]机器学习-part3
  • 运算符与表达式 -《Go语言实战指南》
  • Scala与Go的异同教程
  • 【计算机视觉】OpenCV项目实战:基于OpenCV的图像分割技术深度解析与实践指南
  • 5.1 神经网络: 层和块
  • 电子电器架构 --- 车载以太网拓扑
  • k8s删除pv和pvc后,vg存储没释放分析
  • word换行符和段落标记
  • 2024年AI发展趋势全面解析:从多模态到AGI的突破
  • Python 从 SQLite 数据库中批量提取图像数据
  • 深拷贝与浅拷贝:理解 Python 中的对象复制机制
  • 数据格式(Data Format)设计
  • python3环境安装
  • redis八股--1
  • Redis 主从同步与对象模型(四)
  • JavaScript中对象和数组的常用方法
  • rust-candle学习笔记13-实现多头注意力
  • 嵌入式STM32学习——继电器
  • 大模型微调算法原理:从通用到专用的桥梁
  • 解决mybatisplus主键无法自增的问题
  • Spring之AOP