使用 Python 和 HuggingFace Transformers 进行对象检测
什么是物体检测?
看看你的周围。你很可能会看到电脑显示器、键盘和鼠标,或者,如果你使用移动浏览器,还会看到一部智能手机。
这些都是对象,或者说是某个特定类的实例。例如,下图描绘了一个人体案例。我们还观察到许多“瓶子”类的实例。类是蓝图,而对象是实物,它具有许多独特的特性,但由于具有共同的特性,因此仍然是类的成员。
照片和视频中有很多此类物品的例子。例如,在拍摄交通视频时,你可能会看到很多行人、汽车和自行车。知道它们出现在画面中可能会非常有帮助!
为什么?因为你可以数一数。它可以让你观察一个街区的拥挤程度。另一个例子是在拥堵区域检测停车位,这样你就可以停车了。
等等。
这就是物体检测的 用途!
物体检测和 Transformer
传统上,物体检测是使用卷积神经网络(CNN)进行的。通常,它们的架构是专门为物体检测而设计的,因为它们以图像作为输入,并输出图像的边界框。
如果您熟悉神经网络,就会知道卷积神经网络 (ConvNet) 非常擅长学习图像中的重要特征,而且它们具有空间不变性,这意味着学习到的对象在图像中的位置或大小都无关紧要。如果网络能够识别对象的属性并将其与特定类别关联起来,它就能识别它们。例如,许多不同的猫都可以被识别为属于猫类。
Transformer 设计最近在深度学习领域,尤其是在自然语言处理 (NLP) 领域引起了广泛关注。Transformer 的工作原理是将输入编码为高维状态,然后将其解码为所需的输出。得益于对自注意力机制的巧妙运用,Transformer 不仅能够学习识别特定模式,还能将这些模式与其他模式关联起来。例如,Transformer 可以学习将猫与其特征点(例如沙发)联系起来。
如果 Transformer 可以用来对图像进行分类,那么只需进一步将其用于检测物体即可。Carion 等人 (2020) 证明了使用基于 Transformer 的架构可以实现这一点。在他们的论文《使用 Transformer 的端到端对象检测》中,他们描述了检测 Transformer (DeTr),我们今天将利用它来构建我们的对象检测流程。
它的工作原理如下,甚至没有完全放弃 CNN:
- 使用卷积神经网络,从输入图像中提取重要特征。这些特征会进行位置编码,就像语言 Transformers 一样,以帮助神经网络学习这些特征在图像中的位置。
- 输入被平坦化,然后使用变换器编码器和注意力机制将其编码为中间状态。
- Transformer 解码器的输入是此状态以及在训练过程中获得的一组学习到的对象查询。你可以将它们想象成问题:“这里有一个物体吗?因为我之前在很多情况下都见过一个。” 这些问题将使用中间状态来回答。
- 解码器的输出是通过多个预测头进行的一组预测:每个查询一个。由于 DeTr 中的查询数量默认设置为 100,因此除非您进行其他配置,否则它只能预测一幅图像中的 100 个对象。
HuggingFace Transformers 及其对象检测管道
现在您了解了 DeTr 的工作原理,您可以使用它来构建实际的对象检测管道!
为此,我们将使用 HuggingFace Transformers,其设计旨在简化 NLP 和计算机视觉 Transformers 的处理。事实上,它的使用非常简单,只需加载 ObjectDetectionPipeline 即可,该管道默认加载一个使用 ResNet-50 主干网络训练的 DeTr Transformer 来生成图像特征。
现在让我们开始了解技术细节!
ObjectDetectionPipeline 可以轻松初始化为管道实例……换句话说,使用 pipeline(“object-detection”),如下例所示。根据 GitHub (nd) 的说明,当没有提供额外输入时,管道将按如下方式启动:
<span style="background-color:#f9f9f9"><span style="color:#242424"> <span style="color:#c41a16">“对象检测”</span>:{ <span style="color:#c41a16">“impl”</span>:ObjectDetectionPipeline,<span style="color:#c41a16">“tf”</span>:(),<span style="color:#c41a16">“pt”</span>:(AutoModelForObjectDetection,)<span style="color:#aa0d91">如果</span>is_torch_available()<span style="color:#aa0d91">否则</span>(),<span style="color:#c41a16">“默认”</span>:{ <span style="color:#c41a16">“模型”</span>:{ <span style="color:#c41a16">“pt”</span>:<span style="color:#c41a16">“facebook / detr-resnet-50”</span> }},<span style="color:#c41a16">“类型”</span>:<span style="color:#c41a16">“图像”</span>,},</span></span>
不出所料,我们使用了一个 ObjectDetectionPipeline 实例,该实例针对对象检测进行了优化。在 PyTorch 版本的 HuggingFace Transformers 中,我们使用了 AutoModelForObjectDetection 来实现这一点。有趣的是,目前还没有这个管道的 TensorFlow 实现……?!
正如您所了解的,默认情况下,该facebook/detr-resnet-50
模型用于获取图像特征:
DEtection TRansformer (DETR) 模型在 COCO 2017 目标检测数据集(118k 张带注释图像)上进行了端到端训练。该模型由 Carion 等人在论文《End-to-End Object Detection with Transformers》中提出。
HuggingFace(nd)
COCO 数据集(Common Objects in Context,常见对象数据集)是用于对象检测模型的标准数据集之一,并用于训练此模型。不用担心,您当然也可以训练自己的基于 DeTr 的模型。
重要提示!要使用ObjectDetectionPipeline
,必须timm
安装包含 PyTorch 图像模型的软件包。如果您尚未安装,请务必运行以下命令: pip install timm
。
使用 Python 实现简单的对象检测管道
现在让我们看一下使用 Python 实现一个简单的对象检测解决方案。
回想一下,我们正在使用 HuggingFace Transformers,它必须安装到您的系统上 -pip install transformers
如果您还没有安装它,请运行。
我们还假设已安装 PyTorch,它是目前最流行的深度学习库之一。如前所述,使用 pipeline(“object-detection”) 时将在后台加载的 ObjectDetectionPipeline 不包含 TensorFlow 实例;因此,需要 PyTorch。
这是我们将运行本文后面创建的对象检测管道的图像:
我们从导入开始:
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#aa0d91">从</span>变压器<span style="color:#aa0d91">导入</span>管道
<span style="color:#aa0d91">从</span> PIL <span style="color:#aa0d91">导入</span> 图像、ImageDraw、ImageFont</span></span>
显然,我们使用了transformers
,特别是它的pipeline
表示形式。然后,我们还使用了一个PIL
Python 库来加载、可视化和处理图像。具体来说,我们使用第一个导入来Image
加载图像并ImageDraw
绘制边界框和标签,后者也需要ImageFont.
说到两者,接下来是加载字体(我们选择 Arial)并初始化我们上面介绍的对象检测管道。
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#007400"># 加载字体</span>
font = ImageFont.truetype( <span style="color:#c41a16">"arial.ttf"</span> , <span style="color:#1c00cf">40</span> ) <span style="color:#007400"># 初始化对象检测管道</span>
object_detector = pipeline( <span style="color:#c41a16">"object-detection"</span> )</span></span>
如果您使用的是 Linux,请使用以下命令:
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#836c28">font</span> = ImageFont.true <span style="color:#aa0d91">type</span> ( <span style="color:#c41a16">"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"</span> , <span style="color:#1c00cf">40</span> )
</span></span>
然后,我们定义一个名为 draw_bounding_box 的函数,不出所料,它将用于绘制边界框。它接受图像 (im)、类别概率、边界框坐标、将应用此定义的边界框列表中的边界框索引以及列表长度作为输入。
- 首先,在图像顶部绘制实际的边界框,表示为
rounded_rectangle
红色且半径较小的 bbox,以确保边缘平滑。 - 其次,在边界框略上方绘制文本标签。
- 最后,返回中间结果,以便我们可以在上面绘制下一个边界框和标签。
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#007400"># 绘制边界框定义</span>
<span style="color:#aa0d91">def </span> draw_bounding_box ( <span style="color:#5c2699">im, score, label, xmin, ymin, xmax, ymax, index, num_boxes</span> ): <span style="color:#c41a16">""" 绘制边界框。 """ </span><span style="color:#5c2699">print</span> ( <span style="color:#c41a16">f"Drawing bounding box <span style="color:#000000">{index}</span> of <span style="color:#000000">{num_boxes}</span> ..."</span> ) <span style="color:#007400"># 绘制实际边界框</span>
im_with_rectangle = ImageDraw.Draw(im)
im_with_rectangle.rounded_rectangle((xmin, ymin, xmax, ymax), outline = <span style="color:#c41a16">"red"</span> , width = <span style="color:#1c00cf">5</span> , radius = <span style="color:#1c00cf">10</span> ) <span style="color:#007400"># 绘制标签</span>
im_with_rectangle.text((xmin+ <span style="color:#1c00cf">35</span> , ymin- <span style="color:#1c00cf">25</span> ), label, fill= <span style="color:#c41a16">"white"</span> , stroke_fill = <span style="color:#c41a16">"red"</span> , font = font) <span style="color:#007400"># 返回中间结果</span><span style="color:#aa0d91">return</span> im</span></span>
剩下的就是核心部分——使用,pipeline
然后根据其结果绘制边界框。
以下是我们的操作方法。
首先,图像(我们称之为street.jpg
,与 Python 脚本位于同一目录中)将被打开并存储在一个im
PIL 对象中。我们只需将其输入到初始化函数中,object_detector
这足以让模型返回边界框!剩下的交给 Transformers 库处理。
然后,我们将数据分配给一些变量并迭代每个结果,绘制边界框。
最后,我们将图像保存为street_bboxes.jpg
就是这样!
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#007400">#</span>
<span style="color:#aa0d91">使用</span>Image.open ( <span style="color:#c41a16">"street.jpg"</span> ) <span style="color:#aa0d91">as im</span><span style="color:#5c2699">打开</span>图像:<span style="color:#007400"># 执行物体检测</span>bounding_boxes = object_detector(im) <span style="color:#007400"># 迭代元素</span>num_boxes = <span style="color:#5c2699">len</span> (bounding_boxes) index = <span style="color:#1c00cf">0 </span><span style="color:#007400"># 为每个结果绘制边界框</span><span style="color:#aa0d91">for</span> bounding_box <span style="color:#aa0d91">in</span> bounding_boxes: <span style="color:#007400"># 获取实际框</span> box = bounding_box[ <span style="color:#c41a16">"box"</span> ] <span style="color:#007400"># 绘制边界框</span> im = draw_bounding_box(im, bounding_box[ <span style="color:#c41a16">"score"</span> ], bounding_box[ <span style="color:#c41a16">"label"</span> ],\ box[ <span style="color:#c41a16">"xmin"</span> ], box[ <span style="color:#c41a16">"ymin"</span> ], box[ <span style="color:#c41a16">"xmax"</span> ], box[ <span style="color:#c41a16">"ymax"</span> ], index, num_boxes) <span style="color:#007400"># 将索引增加一</span> index += <span style="color:#1c00cf">1 </span><span style="color:#007400"># 保存图像</span>im.save( <span style="color:#c41a16">"street_bboxes.jpg"</span> ) <span style="color:#007400"># 完成</span><span style="color:#5c2699">print</span> ( <span style="color:#c41a16">"Done!"</span> )</span></span>
使用不同的模型/使用您自己的模型进行对象检测
如果您确实创建了自己的模型或想要使用不同的模型,那么使用它来代替基于 ResNet-50 的 DeTr Transformer 非常容易。
这样做需要您将以下内容添加到导入中:
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#aa0d91">从</span>transforms<span style="color:#aa0d91">导入</span>DetrFeatureExtractor、DetrForObjectDetection</span></span>
然后,您可以初始化特征提取器和模型,并object_detector
用它们来初始化(而不是使用默认的)。例如,如果您想使用 ResNet-101 作为主干网络,可以按如下方式操作:
<span style="background-color:#f9f9f9"><span style="color:#242424"><span style="color:#007400"># 初始化另一个模型和特征提取器</span>
feature_extractor = DetrFeatureExtractor.from_pretrained( <span style="color:#c41a16">'facebook/detr-resnet-101'</span> )
model = DetrForObjectDetection.from_pretrained( <span style="color:#c41a16">'facebook/detr-resnet-101'</span> ) <span style="color:#007400"># 初始化对象检测管道</span>
object_detector = pipeline( <span style="color:#c41a16">"object-detection"</span> , model = model, feature_extractor = feature_extractor)</span></span>
结果
这是我们在输入图像上运行对象检测管道后得到的结果:
或者,放大后:
完整代码
对于想要立即开始使用的人,这里有完整的代码: