当前位置: 首页 > news >正文

【深入浅出PyTorch】--8.1.PyTorch生态--torchvision

🎯 进阶之路:走进 PyTorch 生态系统

恭喜你!经过前七章的学习,你已经掌握了 PyTorch 的核心组件:

  • 张量操作与自动求导(torch.Tensorautograd
  • 神经网络构建(torch.nn
  • 模型训练流程(前向传播、损失计算、反向传播、优化器)
  • 自定义模型结构与模块化设计
  • 使用 TensorBoard 等工具进行可视化分析

现在,是时候将这些技能应用到真实世界的任务中了。而要做到这一点,离不开 PyTorch 生态系统 的支持。

🌐 什么是 PyTorch 生态系统?

PyTorch 之所以成为深度学习研究和工业界的主流框架之一,不仅仅因为它 API 设计简洁、动态图机制灵活,更重要的是其背后强大的 开源社区生态。这个生态系统由一系列高质量的第三方库组成,它们基于 PyTorch 构建,专注于解决特定领域的实际问题。

这些工具包通常提供:

功能示例
数据加载与预处理DatasetDataLoader 的高级封装
数据增强领域专用的图像/文本/图数据增广方法
预定义模型结构ResNet, BERT, GCN 等经典模型一键调用
预训练权重ImageNet、COCO、WikiText 等大规模数据集上的预训练模型
损失函数与评估指标分类、检测、分割等任务专用指标
训练/推理流水线封装简化训练循环,支持分布式训练
可视化与调试工具特征图可视化、注意力机制展示等

🔧 主要领域及其代表性工具包

领域工具包简介
计算机视觉TorchVision官方视觉库,包含数据集(CIFAR, ImageNet)、模型(ResNet, Faster R-CNN)、变换(transforms)等
视频理解TorchVideo(已并入 PyTorchVideo)Facebook AI 开发,支持视频分类、动作识别等
自然语言处理torchtext官方 NLP 库,支持文本数据处理、词向量、语言模型等(注意:新版已简化 API)
替代方案:Hugging Face Transformers + Datasets
图神经网络PyG (PyTorch Geometric)德国达姆施塔特工业大学开发,支持图卷积、图注意力等复杂结构
音频处理torchaudio官方音频库,支持语音识别、音频分类、声纹识别等
强化学习TorchRL / Stable-Baselines3 (支持 PyTorch)提供 RL 算法实现与环境接口
模型部署[TorchScriptTorchServeONNX]支持模型导出、序列化与生产环境部署

💡 小贴士:虽然 torchtexttorchaudio 是官方库,但在实际项目中,许多开发者更倾向于使用 Hugging Face TransformersKaldi 等更成熟的生态工具。


目录

1.torchvision--简介

1.1.torchvision.datasets(数据集)

📷 一、图像分类(Image Classification)

🏙️ 二、目标检测与语义分割(Object Detection & Segmentation)

👤 三、人脸识别与人脸分析(Face Recognition / Analysis)

🚗 四、自动驾驶与三维感知(Autonomous Driving & 3D Vision)

🎥 五、视频理解(Video Understanding)

🌆 六、场景识别与自然图像(Scene Recognition)

🧠 七、合成与特殊用途数据集(Synthetic / Utility)

1.2.torchvision.transforms(数据增强)

1.3.torchvision.models(模型加载)

1.3.1.图像分类(Image Classification)

1.3.2.语义分割(Semantic Segmentation)

1.3.3.目标检测、实例分割与关键点检测

1.4.torchvision.io(数据读写)

1.5.torchvision.ops(视觉操作)

1.6.torchvision.utils(可视化)


1.torchvision--简介

在前面的学习和实战中,我们经常会用到torchvision来调用预训练模型,加载数据集,对图片进行数据增强的操作。在本章我们将给大家简单介绍下torchvision以及相关操作。

" The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision. "

正如引言介绍的一样,我们可以知道torchvision包含了在计算机视觉中常常用到的数据集,模型和图像处理的方式,而具体的torchvision则包括了下面这几部分,带 ***** 的部分是我们经常会使用到的一些库,所以在下面的部分我们对这些库进行一个简单的介绍:

  • torchvision.datasets *

  • torchvision.models *

  • torchvision.tramsforms *

  • torchvision.io

  • torchvision.ops

  • torchvision.utils

1.1.torchvision.datasets(数据集)

torchvision.datasets主要包含了一些我们在计算机视觉中常见的数据集,在==0.10.0版本==的torchvision下,有以下的数据集:

任务推荐数据集
图像分类CIFAR-10/100, Fashion-MNIST, STL10, Caltech
场景识别Places365
目标检测VOC, COCO, WIDERFace
语义分割Cityscapes, SBD, VOC
人脸分析CelebA, WIDERFace
动作识别Kinetics-400, UCF101
自动驾驶KITTI, Cityscapes
模型调试FakeData
字符识别EMNIST, KMNIST, QMNIST, SVHN

📷 一、图像分类(Image Classification)

这些数据集主要用于图像分类任务,是深度学习入门和模型验证的经典基准。

数据集简介
CIFAR-10 / CIFAR-100小尺寸彩色图像数据集(32×32),分别包含10类和100类物体。常用于测试模型性能。
Fashion-MNIST10类服装的灰度图(28×28),替代原始 MNIST 手写数字,更具挑战性。
KMNISTKuzushiji-MNIST,日文草书字符的灰度图像(28×28),用于字符识别研究。
EMNIST扩展版 MNIST,包含手写字母和数字,结构与 MNIST 兼容但类别更多。
QMNISTMNIST 的扩展版本,来自原始 NIST 数据库,可用于更精确的实验对比。
SVHN (Street View House Numbers)街景门牌号码图像(32×32 彩色),真实场景下的数字识别任务,比 MNIST 更复杂。
STL10类似 CIFAR-10,但图像更大(96×96),且提供无标签数据用于半监督学习。

🏙️ 二、目标检测与语义分割(Object Detection & Segmentation)

用于目标检测、实例分割、语义分割等高级视觉任务。

数据集简介
VOC (Pascal Visual Object Classes)经典多任务数据集(2007/2012),包含 20 类物体,支持分类、检测、分割任务。
COCO (via torchvision 支持加载)虽然 torchvision 不直接命名 COCO,但可通过 CocoDetectionCocoCaptions 接口加载(需手动下载)。大规模目标检测、分割、图像描述数据集。
Cityscapes城市场景下的语义分割数据集,包含 5000 张精细标注图像,30 类城市交通元素(车、行人、道路等)。
SBD (Semantic Boundaries Dataset)VOC 的扩展,提供像素级语义分割标注,常用于语义分割模型训练。
WIDERFace大规模人脸检测数据集,包含各种姿态、遮挡、光照条件下的人脸,标注了边界框和关键点。

👤 三、人脸识别与人脸分析(Face Recognition / Analysis)

专注于人脸相关任务,如检测、关键点定位、属性识别。

数据集简介
CelebA (CelebFaces Attributes Dataset)包含超过 20 万张名人脸部图像,每张标注了 40 种属性(如微笑、戴眼镜)和 5 个关键点。
EMNIST(也适用于手写识别)包含手写英文字母,可用于字符级别的人脸之外的模式识别。

🚗 四、自动驾驶与三维感知(Autonomous Driving & 3D Vision)

用于自动驾驶相关的视觉任务,结合图像与激光雷达数据。

数据集简介
KITTI自动驾驶基准数据集,包含图像、激光雷达点云、GPS/IMU 数据,支持目标检测、立体匹配、光流估计等任务。
Cityscapes同上,也广泛用于自动驾驶中的语义分割任务。

🎥 五、视频理解(Video Understanding)

用于视频分类、动作识别等时序建模任务。

数据集简介
Kinetics-400大规模动作识别数据集,包含约 24 万段 YouTube 视频,涵盖 400 种人类动作(如跳舞、游泳)。
UCF101视频动作识别经典数据集,101 类动作,约 1.3 万段视频,适合小规模实验。

🌆 六、场景识别与自然图像(Scene Recognition)

用于场景分类、图像检索等任务。

数据集简介
Places365包含 365 种室内外场景的大规模数据集(如厨房、森林、机场),适合场景分类模型预训练。
PhotoTour (Liberty, Yosemite, Notre Dame)用于图像匹配和特征提取,常用于训练描述子(如 SuperPoint),基于旅游照片的匹配任务。

🧠 七、合成与特殊用途数据集(Synthetic / Utility)

用于调试、测试或教学目的。

数据集简介
FakeData合成随机数据集,不依赖磁盘文件,用于快速测试数据加载流程和模型结构是否正确。
Caltech-101 / Caltech-256图像分类数据集,分别包含 101 和 256 类物体,图像质量高,类别多样。

调用可以使用代码:

import torchvision.datasets as datasets
import torchvision.transforms as transforms# 示例:加载 CIFAR-10
transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)# 示例:加载 Cityscapes(需指定 split)
cityscape = datasets.Cityscapes(root='./cityscapes', split='train', mode='fine', target_type='semantic', transform=transform)# 示例:加载 Kinetics-400(注意:需要大量存储空间)
kinetics = datasets.Kinetics400(root='./kinetics', frames_per_clip=16, step_between_clips=5)

1.2.torchvision.transforms(数据增强)

我们知道在计算机视觉中处理的数据集有很大一部分是图片类型的,如果获取的数据是格式或者大小不一的图片,则需要进行归一化和大小缩放等操作,这些是常用的数据预处理方法。除此之外,当图片数据有限时,我们还需要通过对现有图片数据进行各种变换,如缩小或放大、水平或垂直翻转等,这些是常见的数据增强方法。而torchvision.transforms中就包含了许多这样的操作。在之前第四章的Fashion-mnist实战中对数据的处理时我们就用到了torchvision.transformer:

from torchvision import transforms
data_transform = transforms.Compose([transforms.ToPILImage(),   # 这一步取决于后续的数据读取方式,如果使用内置数据集则不需要transforms.Resize(image_size),transforms.ToTensor()
])

除了上面提到的几种数据增强操作,在torchvision官方文档里提到了更多的操作,具体使用方法也可以参考本节配套的”transforms.ipynb“,在这个notebook中我们给出了常见的transforms的API及其使用方法,更多数据变换的操作我们可以点击这里进行查看。

1.3.torchvision.models(模型加载)

为了提升深度学习研究与应用的效率,避免“重复造轮子”,PyTorch 团队通过 torchvision.models 模块提供了大量在大规模数据集上预训练好的模型。这些模型不仅可用于迁移学习(Transfer Learning),也可以作为新模型设计的基准或特征提取器使用。

🔗 官方文档地址:https://pytorch.org/vision/stable/models.html

此外,如果你希望获取更多非官方但高质量的预训练模型实现,可以参考社区项目:

🔗 pretrained-models.pytorch —— 包含 DenseNet-BC、SE-Net 等多种扩展模型

1.3.1.图像分类(Image Classification)

所有分类模型均在 ImageNet-1K 数据集 上进行预训练,包含 1000 类自然图像,输入尺寸通常为 (3, 224, 224)

模型名称特点简介
AlexNet2012 年 ImageNet 冠军,首次使用 GPU 加速训练的 CNN,奠定现代 CNN 基础。
VGG (VGG11/13/16/19)结构简单规整,全部使用 3×3 卷积堆叠,适合教学理解。
ResNet (18/34/50/101/152)引入残差连接(Residual Block),解决深层网络梯度消失问题,广泛用于各种任务 backbone。
SqueezeNet参数极少(<5MB),轻量级设计,“Fire Module” 实现高效压缩。
DenseNet每一层都与后续层直接相连,增强特征复用,减少参数冗余。
Inception v3 / GoogLeNet使用多分支结构(Inception Module)提升感受野多样性;GoogLeNet 是原始版本,Inception v3 是改进版。
ShuffleNet v2针对移动端优化,强调内存访问效率(MAC),适合嵌入式设备。
MobileNetV2 / V3轻量化代表,使用倒置残差 + 线性瓶颈模块;V3 结合 NAS 搜索进一步优化。
ResNeXt分组卷积 + 残差结构,提升模型表达能力(如 ResNeXt-50)。
Wide ResNet在 ResNet 基础上加宽通道数,提升性能而非加深层数。
MNASNet使用神经架构搜索(NAS)自动设计的轻量模型,平衡精度与延迟。
EfficientNet (B0-B7)使用复合缩放策略统一调整深度、宽度、分辨率,在小模型到大模型间取得最优性能。
RegNetFacebook 提出的新一代 backbone 设计原则,参数更可控,性能优于 ResNeXt。

📌 查看准确率
👉 ImageNet 分类模型性能对比表

1.3.2.语义分割(Semantic Segmentation)

模型在 COCO train2017 的 stuffthingmaps 子集 上训练,输出每个像素的类别标签,共 21 类(含背景):

classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle','bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse','motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
模型主干网络(Backbone)特点
FCN-ResNet50 / FCN-ResNet101ResNet50/101全卷积网络鼻祖,端到端像素预测。
DeepLabV3-ResNet50 / DeepLabV3-ResNet101ResNet50/101使用空洞卷积(ASPP)扩大感受野,边界处理更好。
LR-ASPP MobileNetV3-LargeMobileNetV3-Large轻量级实时分割模型,适用于移动场景。
DeepLabV3-MobileNetV3-LargeMobileNetV3-Large高精度轻量分割方案,兼顾速度与效果。

🎯 输出形式:(N, num_classes, H, W),需通过 argmax(dim=1) 获取预测类别图。

📊 评估指标

  • Mean IoU (Intersection over Union):平均交并比
  • Pixel Accuracy:全局像素正确率

📌 查看性能指标
👉 语义分割模型性能对比

1.3.3.目标检测、实例分割与关键点检测

模型均在 COCO train2017 数据集上训练,支持以下任务:

📌 类别信息(共 91 类,部分标记为 'N/A')

COCO_INSTANCE_CATEGORY_NAMES = ['__background__', 'person', 'bicycle', 'car', ..., 'toothbrush'
]  # 共 91 类(含背景)

📌 人体关键点名称(共 17 个关键点)

COCO_PERSON_KEYPOINT_NAMES = ['nose','left_eye', 'right_eye','left_ear', 'right_ear','left_shoulder', 'right_shoulder','left_elbow', 'right_elbow','left_wrist', 'right_wrist','left_hip', 'right_hip','left_knee', 'right_knee','left_ankle', 'right_ankle'
]

✅ 支持的模型架构

模型支持任务特点
Faster R-CNN目标检测两阶段检测器经典之作,精度高,速度适中。
Mask R-CNN检测 + 实例分割 + 关键点在 Faster R-CNN 基础上增加 mask head 和 keypoint head,全能型模型。
RetinaNet目标检测单阶段检测器,引入 Focal Loss 解决正负样本不平衡问题。
SSD / SSDlite目标检测单阶段高速检测器;SSDlite 使用 MobileNet 为主干,适合移动端。

🎯 输出格式(以 Mask R-CNN 为例):

{'boxes': tensor(N, 4),        # 边界框 [x1, y1, x2, y2]'labels': tensor(N,),         # 对应类别 ID'scores': tensor(N,),         # 置信度得分'masks': tensor(N, 1, H, W),  # 实例分割掩码'keypoints': tensor(N, 17, 3) # (x, y, visibility)
}

📊 评估指标(COCO 标准)

  • Box AP:边界框平均精度(IoU=0.5:0.95)
  • Mask AP:实例分割 AP
  • Keypoint AP:关键点检测 AP

📌 查看性能指标
👉 目标检测与分割模型性能对比

1.4.torchvision.io(数据读写)

torchvision.io提供了视频、图片和文件的 IO 操作的功能,它们包括读取、写入、编解码处理操作。随着torchvision的发展,io也增加了更多底层的高效率的API。在使用torchvision.io的过程中,我们需要注意以下几点:

  • 不同版本之间,torchvision.io有着较大变化,因此在使用时,需要查看下我们的torchvision版本是否存在你想使用的方法。

  • 除了read_video()等方法,torchvision.io为我们提供了一个细粒度的视频API torchvision.io.VideoReader() ,它具有更高的效率并且更加接近底层处理。在使用时,我们需要先安装ffmpeg然后从源码重新编译torchvision我们才能我们能使用这些方法。

  • 在使用Video相关API时,我们最好提前安装好PyAV这个库。

图像读写(Image I/O)

import torch
import torchvision.io as io
from torchvision.utils import save_image
import matplotlib.pyplot as plt# 读取图像为 Tensor (C, H, W),值范围 [0, 255],dtype=torch.uint8
image_tensor = io.read_image("example.jpg")  # 自动判断格式(JPEG/PNG等)print(image_tensor.shape)  # e.g., torch.Size([3, 480, 640])
print(image_tensor.dtype)  # torch.uint8# 写入图像
io.write_jpeg(image_tensor, "output.jpg")        # 仅支持 uint8
io.write_png(image_tensor, "output.png")# 转换为 float 并归一化用于模型输入
image_float = image_tensor.float() / 255.0

1.5.torchvision.ops(视觉操作)

这些操作在目标检测、分割等任务中至关重要。

torchvision.ops 为我们提供了许多计算机视觉的特定操作,包括但不仅限于NMS,RoIAlign(MASK R-CNN中应用的一种方法),RoIPool(Fast R-CNN中用到的一种方法)。在合适的时间使用可以大大降低我们的工作量,避免重复的造轮子,想看更多的函数介绍可以点击这里进行细致查看。

✅ 1. NMS(非极大值抑制)

from torchvision.ops import nms
import torchboxes = torch.tensor([[0, 0, 100, 100],[50, 50, 150, 150],[80, 80, 180, 180]
], dtype=torch.float)scores = torch.tensor([0.9, 0.8, 0.7])keep_indices = nms(boxes, scores, iou_threshold=0.5)
print(keep_indices)  # 保留的框索引,如 tensor([0, 2])

1.6.torchvision.utils(可视化)

torchvision.utils 为我们提供了一些可视化的方法,可以帮助我们将若干张图片拼接在一起、可视化检测和分割的效果。具体方法可以点击这里进行查看。

总的来说,torchvision的出现帮助我们解决了常见的计算机视觉中一些重复且耗时的工作,并在数据集的获取、数据增强、模型预训练等方面大大降低了我们的工作难度,可以让我们更加快速上手一些计算机视觉任务。

 1. 图像网格拼接(make_grid

# -*- coding: utf-8 -*-
"""
可视化工具示例:使用 torchvision.utils.make_grid 和 draw_bounding_boxes
展示图像网格拼接与目标检测结果可视化
"""import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid, draw_bounding_boxes
from torchvision.io import read_image
from PIL import Image
import matplotlib.pyplot as plt# -------------------------------
# 1. 图像网格拼接(make_grid 示例)
# -------------------------------# 生成一些随机图像数据(模拟模型输出或数据集样本)
# 假设是 [-1, 1] 范围内的张量(例如 GAN 输出)
images = torch.randn(16, 3, 64, 64)# 方法一:使用 normalize=True 自动归一化到 [0,1]
# ❌ 错误写法(旧版): range=(-1, 1)
# grid = make_grid(images, nrow=4, padding=2, normalize=True, range=(-1, 1))# ✅ 正确写法(新版本推荐)
grid = make_grid(images, nrow=4, padding=2, normalize=True)  # 自动处理 [-1,1] -> [0,1]# 显示图像网格
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC
plt.axis("off")
plt.title("Image Grid (Normalized)")
plt.show()# -------------------------------
# 2. 加载真实图像并绘制边界框
# -------------------------------# 尝试读取本地图片(请确保路径正确,或替换为你的图片路径)
image_path = "example.jpg"  # 修改为你自己的图片路径try:image_tensor = read_image(image_path)  # (C, H, W), dtype=torch.uint8
except FileNotFoundError:print(f"[警告] 找不到 {image_path},使用随机图像代替。")image_tensor = (torch.rand(3, 256, 384) * 255).to(torch.uint8)# 定义一些边界框(格式:[x1, y1, x2, y2])
boxes = torch.tensor([[50, 50, 200, 200],[100, 150, 300, 250]
])labels = ["Person", "Car"]
colors = ["red", "blue"]# 使用 draw_bounding_boxes 绘制框(返回 uint8 tensor)
try:image_with_boxes = draw_bounding_boxes(image_tensor,boxes=boxes,labels=labels,colors=colors,width=3,font_size=20  # 如果提示找不到字体,可尝试注释此行)
except TypeError as e:if "font_size" in str(e):print("[提示] font_size 不支持,可能是 torchvision 版本较老,忽略该参数。")image_with_boxes = draw_bounding_boxes(image_tensor,boxes=boxes,labels=labels,colors=colors,width=3)else:raise e# 转为 PIL 图像显示
pil_image = Image.fromarray(image_with_boxes.permute(1, 2, 0).numpy())
pil_image.show()  # 或者用 plt 显示

http://www.dtcms.com/a/565344.html

相关文章:

  • Blender新手入门,超详细!!!
  • Milvus:数据库层操作详解(二)
  • Blender入门学习09 - 制作动画
  • 网站建设终身不用维护网络推广主要内容
  • 金融知识详解:隔日差错处理机制与银行实战场景
  • 网站运营编辑浙江久天建设有限公司网站
  • 做网站销售说辞有赞商城官网登录
  • MATLAB实现基于RPCA的图像稀疏低秩分解
  • 象山企业门户网站建设扬州高端网站制作
  • 服务器网站建设维护app制作定制开发
  • php企业网站开发方案服装外贸网站建设
  • 【Go】--互斥锁和读写锁
  • 《从适配器本质到面试题:一文掌握 C++ 栈、队列与优先级队列核心》
  • 心理咨询网站模板做网站手机
  • 光学3D表面轮廓仪中Rz代表什么?如何精准测量Rz?
  • ps做登录网站北京网站制作工作室
  • git rebase提交
  • vue3引入icon-font
  • 基于开源操作系统搭建K8S高可用集群
  • 学做网站论坛 可以吗做网站是不是太麻烦了
  • leetcode 1578 使绳子变成彩色的最短时间
  • 中国建设银行网上银行官方网站长沙优秀网站建设
  • 1.7 Foundry介绍
  • 什么是向量数据库?主流产品介绍与实战演练
  • redission实现延时队列
  • 浏览器端缓存地图请求:使用 IndexedDB + ajax-hook 提升地图加载速度
  • 地铁工程建设论文投稿网站谷歌广告代运营
  • 广东备案网站软件开发怎么学
  • 【成长纪实】鸿蒙 ArkTS 语言从零到一完整指南
  • PyTorch模型部署实战:从TorchScript到LibTorch的完整路径