小白的进阶之路系列之三----人工智能从初步到精通pytorch计算机视觉详解上
计算机视觉是教计算机看东西的艺术。
例如,它可能涉及构建一个模型来分类照片是猫还是狗(二元分类)。
或者照片是猫、狗还是鸡(多类分类)。
或者识别汽车出现在视频帧中的位置(目标检测)。
或者找出图像中不同物体可以被分离的位置(全视分割)。
计算机视觉应用在哪里?
如果你使用智能手机,你已经使用了计算机视觉。
相机和照片应用程序使用计算机视觉来增强和分类图像。
现代汽车使用计算机视觉来避开其他车辆并保持在车道线内。
制造商使用计算机视觉来识别各种产品中的缺陷。
安全摄像头使用计算机视觉来探测潜在的入侵者。
本质上,任何可以用视觉描述的东西都可能是潜在的计算机视觉问题。
我们要讲的内容
我们将把在过去几节中学习的PyTorch 工作流程应用到计算机视觉中。
今天这篇文章将涵盖一下一些内容:
主题 | 内容 |
---|---|
PyTorch中的计算机视觉库 | PyTorch有很多内置的有用的计算机视觉库,让我们来看看。 |
载入数据 | 为了练习计算机视觉,我们将从FashionMNIST上的一些不同服装的图像开始。 |
准备数据 | 我们有一些图像,让我们用PyTorch DataLoader加载它们,这样我们就可以在训练循环中使用它们。 |
Model 0:建立基线模型 | 这里我们将创建一个多类分类模型来学习数据中的模式,我们还将选择损失函数,优化器并构建训练循环。 |
做出预测并评估Model 0 | 让我们用我们的基线模型做一些预测并评估它们。 |
为将来的模型设置通用代码(设备无关) | 编写与设备无关的代码是最佳实践,因此让我们来设置它。 |
0 PyTorch中的计算机视觉库
在我们开始编写代码之前,让我们讨论一下您应该了解的一些PyTorch计算机视觉库。
Pytorch模块 | 功能 |
---|---|
torchvision | 包含数据集,模型架构和图像转换,通常用于计算机视觉问题。 |
torchvision.datasets | 在这里,您将找到许多示例计算机视觉数据集,用于图像分类,目标检测,图像字幕,视频分类等一系列问题。它还包含一系列用于创建自定义数据集的基类。 |
torchvision.models | 此模块包含在PyTorch中实现的性能良好且常用的计算机视觉模型体系结构,您可以将其用于解决自己的问题。 |
torchvision.transforms | 在与模型一起使用之前,通常需要对图像进行转换(转换为数字/处理/增强),这里可以找到常见的图像转换。 |
torch.utils.data.Dataset | PyTorch的基本数据集类。 |
torch.utils.data.DataLoader | 在数据集上创建一个Python可迭代对象(使用torch.utils.data.Dataset创建)。 |
[!TIP]
注意:torch.utils.data.Dataset和torch.utils.data.DataLoader类不仅用于PyTorch中的计算机视觉,它们还能够处理许多不同类型的数据。
现在我们已经介绍了一些最重要的PyTorch计算机视觉库,让我们导入相关的依赖项。
# Import PyTorch
import torch
from torch import nn# Import torchvision
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor# Import matplotlib for visualization
import matplotlib.pyplot as plt# Check versions
# Note: your PyTorch version shouldn't be lower than 1.10.0 and torchvision version shouldn't be lower than 0.11
print(f"PyTorch version: {torch.__version__}\ntorchvision version: {torchvision.__version__}")
输出为:
PyTorch version: 2.7.0+cu118
torchvision version: 0.22.0+cu118
1 获取数据集
为了开始研究计算机视觉问题,让我们获得一个计算机视觉数据集。
我们从FashionMNIST开始。
MNIST代表修改后的国家标准与技术研究所。
原始的MNIST数据集包含数千个手写数字(从0到9)的示例,并用于构建计算机视觉模型来识别邮政服务的数字。
Zalando Research制作的FashionMNIST也是类似的设置。
除了它包含了10种不同服装的灰度图像。
torchvision.datasets包含许多示例数据集,您可以使用它们来练习编写计算机视觉代码。FashionMNIST就是其中一个数据集。由于它有10个不同的图像类别(不同类型的服装),所以它是一个多类别分类问题。
稍后,我们将建立一个计算机视觉神经网络来识别这些图像中不同风格的服装。
PyTorch有一堆常见的计算机视觉数据集存储在torchvision.datasets中。
在torchvision.datasets.FashionMNIST()中包含FashionMNIST。
要下载它,我们提供以下参数:
-
root: STR -您要将数据下载到哪个文件夹?
-
train:Bool -你想要训练还是测试分割?
-
download:Bool -数据应该下载吗?
-
transform: torchvision.transform—您希望对数据进行哪些转换?
-
Target_transform—如果您喜欢,也可以转换目标(标签)。
torchvision中的许多其他数据集都有这些参数选项。
# Setup training data
train_data = datasets.FashionMNIST(root="data", # where to download data to?train=True, # get training datadownload=True, # download data if it doesn't exist on disktransform=ToTensor(), # images come as PIL format, we want to turn into Torch tensorstarget_transform=None # you can transform labels as well
)# Setup testing data
test_data = datasets.FashionMNIST(root="data",train=False, # get test datadownload=True,transform=ToTensor()
)
输出为:
100.0%
100.0%
100.0%
100.0%
让我们来看看训练数据的第一个样本。
# See first training sample
image, label = train_data[0]
print(image, label)
输出为:
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,0.0000, 0.0039, 0.0039, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,0.0157, 0.0000, 0.0000, 0.0118],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,0.0000, 0.0471, 0.0392, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,0.3020, 0.5098, 0.2824, 0.0588],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,0.5529, 0.3451, 0.6745, 0.2588],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,0.4824, 0.7686, 0.8980, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824,