目标检测:从基础原理到前沿技术全面解析
引言
在计算机视觉领域,目标检测是一项核心且极具挑战性的任务,它不仅要识别图像中有什么物体,还要确定这些物体在图像中的具体位置。随着人工智能技术的快速发展,目标检测已成为智能监控、自动驾驶、医疗影像分析等众多应用的基础技术。本文将全面介绍目标检测的基础概念、发展历程、关键技术、实践应用以及未来趋势,为读者提供系统性的知识框架。
第一章 目标检测概述
1.1 目标检测的定义与重要性
目标检测(Object Detection)是计算机视觉中的一项关键任务,其核心目标是在给定图像中精确定位并识别出感兴趣的物体实例。与简单的图像分类不同,目标检测需要解决"在哪里"和"是什么"两个问题,输出通常是物体边界框(Bounding Box)和类别标签的组合。
技术定义:给定输入图像I,目标检测的任务是找出所有感兴趣的物体实例,并为每个实例输出一个边界框b=(x,y,w,h)和类别标签c∈{1,2,…,C},其中(x,y)表示框的中心或角点坐标,(w,h)表示框的宽度和高度,C是预定义的类别数量。
目标检测的重要性体现在多个方面:
- 基础性:是许多高级视觉任务(如实例分割、行为识别)的基础
- 应用广泛:从安防监控到自动驾驶,从工业质检到医疗诊断
- 商业价值:全球计算机视觉市场规模预计2025年将突破200亿美元
- 研究价值:推动了深度学习、特征表示等领域的发展
1.2 目标检测与相关任务的比较
为了更好地理解目标检测,有必要将其与相关计算机视觉任务进行对比:
任务类型 | 输出形式 | 典型应用 | 主要挑战 |
---|---|---|---|
图像分类 | 整个图像的类别标签 | 相册分类、场景识别 | 视角变化、背景干扰 |
目标检测 | 多个边界框+类别 | 自动驾驶、安防监控 | 物体重叠、尺度变化 |
语义分割 | 像素级类别标注 | 医疗影像、遥感图像 | 精细边界、计算成本 |
实例分割 | 像素级实例标注 | 机器人抓取、AR应用 | 实例区分、遮挡处理 |
关键点检测 | 特定点位置 | 姿态估计、人脸识别 | 点定位精度、遮挡 |
1.3 目标检测的核心挑战
目标检测面临诸多技术挑战,主要包括:
- 尺度变化:同一类物体在不同图像中可能呈现极大尺寸差异
- 视角变化:摄像机角度导致物体外观显著不同
- 遮挡问题:目标物体被部分遮挡,仅可见局部特征
- 光照条件:光线变化影响物体外观表现
- 背景干扰:复杂背景与目标物体特征相似
- 类别不平衡:某些类别样本数量远多于其他类别
- 实时性要求:许多应用场景需要高帧率处理
- 小物体检测:图像中小尺寸物体的识别与定位困难
1.4 目标检测的发展历程
目标检测技术的发展大致经历了以下几个阶段:
-
传统方法时代(2001-2012):
- 基于手工设计特征(如HOG、SIFT)
- 滑动窗口+分类器(如SVM)
- 代表工作:Viola-Jones人脸检测、DPM(Deformable Part Model)
-
深度学习初期(2012-2015):
- 两阶段检测器兴起(R-CNN系列)
- 从手工特征到CNN特征转变
- 代表工作:R-CNN、SPPNet、Fast R-CNN
-
快速发展期(2015-2017):
- 单阶段检测器出现(YOLO、SSD)
- 检测效率大幅提升
- 代表工作:Faster R-CNN、YOLOv1/v2、SSD
-
架构创新期(2017-2020):
- 特征金字塔网络(FPN)
- Anchor-free方法兴起
- 代表工作:RetinaNet、CornerNet、CenterNet
-
Transformer时代(2020-至今):
- Vision Transformer应用于检测
- 端到端检测器
- 代表工作:DETR、Swin Transformer、YOLOS
第二章 传统目标检测方法
在深度学习统治计算机视觉之前,传统目标检测方法主要依靠精心设计的特征提取和机器学习算法。这些方法虽然性能不及现代深度学习方法,但其中的许多思想至今仍有借鉴价值。
2.1 特征提取方法
2.1.1 Haar-like特征
由Viola和Jones提出的人脸检测特征:
- 计算图像矩形区域的像素和差值
- 通过积分图加速计算
- 特征简单但有效,适合人脸等刚体检测
import cv2# 加载预训练的Haar级联分类器
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')def detect_faces(image_path):img = cv2.imread(image_path)gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)# 检测人脸faces = face_cascade.detectMultiScale(gray,scaleFactor=1.1,minNeighbors=5,minSize=(30, 30))# 绘制检测框for (x,y,w,h) in faces:cv2.rectangle(img,(x,y),(x+w,y+h),(255,0,0),2)cv2.imshow('Faces detected', img)cv2.waitKey(0)cv2.destroyAllWindows()
2.1.2 HOG(方向梯度直方图)
Navneet Dalal提出的特征描述子:
- 计算图像梯度方向和大小
- 将图像划分为细胞单元
- 统计每个单元的梯度方向直方图
- 块归一化增强光照不变性
from skimage.feature import hog
from skimage import exposure
import matplotlib.pyplot as pltdef extract_hog_features(image_path):image = cv2.imread(image_path)gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)# 计算HOG特征和可视化fd, hog_image = hog(gray, orientations=8,pixels_per_cell=(16,16),cells_per_block=(1,1),visualize=True)# 显示HOG特征fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4))ax1.imshow(gray, cmap=plt.cm.gray)ax1.set_title('Input image')hog_image_rescaled = exposure.rescale_intensity(hog_image, in_range=(0,10))ax2.imshow(hog_image_rescaled, cmap=plt.cm.gray)ax2.set_title('HOG features')plt.show()return fd
2.1.3 SIFT(尺度不变特征变换)
David Lowe提出的局部特征描述子:
- 尺度空间极值检测
- 关键点定位
- 方向分配
- 关键点描述子生成
2.2 检测框架
2.2.1 滑动窗口
最朴素的检测方法:
- 用不同大小的窗口扫描图像
- 对每个窗口提取特征并分类
- 合并重叠检测结果
缺点:计算量大,效率低下
2.2.2 选择性搜索
生成可能包含物体的区域提议:
- 基于颜色、纹理、大小等相似性合并超像素
- 生成不同层次的区域提议
- 减少需要分类的窗口数量
2.2.3 可变形部件模型(DPM)
Felzenszwalb提出的经典方法:
- 将物体建模为根滤波器和部件滤波器的组合
- 考虑部件之间的几何变形惩罚
- 使用潜变量SVM进行训练
2.3 传统方法的局限性
尽管传统方法在特定场景下仍有用武之地,但普遍存在以下问题:
- 特征设计困难:需要专业知识设计特征提取器
- 泛化能力弱:手工特征难以适应多样化的物体外观
- 多尺度处理复杂:需要单独处理不同尺度
- 遮挡处理不足:对部分遮挡的物体识别效果差
- 计算效率低:滑动窗口等方式计算量大
这些局限性促使研究者转向基于深度学习的方法,后者能够自动学习更适合目标检测的特征表示。
第三章 基于深度学习的目标检测方法
深度学习彻底改变了目标检测领域,通过端到端的学习方式大幅提升了检测性能。本章将详细介绍深度学习时代的目标检测方法。
3.1 两阶段检测器
两阶段检测器首先生成区域提议(Region Proposal),然后对这些提议进行分类和回归,精度高但速度相对较慢。
3.1.1 R-CNN系列
R-CNN(2014):
- 使用选择性搜索生成约2000个区域提议
- 对每个区域进行CNN特征提取
- 使用SVM分类
- 边界框回归精修位置
缺点:重复计算多,速度慢
Fast R-CNN(2015)改进:
- 整图通过CNN提取特征
- 通过RoI Pooling将不同大小的提议映射为固定尺寸
- 多任务损失(分类+回归)
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGeneratordef create_faster_rcnn_model(num_classes):# 加载预训练的主干网络backbone = torchvision.models.mobilenet_v2(pretrained=True).featuresbackbone.out_channels = 1280# 定义RPN的anchor生成器anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),aspect_ratios=((0.5, 1.0, 2.0),))# 定义RoI poolingroi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size=7,sampling_ratio=2)# 组装Faster R-CNN模型model = FasterRCNN(backbone,num_classes=num_classes,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler)return model
Faster R-CNN(2015)关键创新:
- 用RPN(Region Proposal Network)替代选择性搜索
- 实现端到端训练
- 速度和精度进一步提升
3.1.2 FPN(特征金字塔网络)
解决多尺度检测问题:
- 自顶向下路径融合不同层级的特征
- 不同尺度的物体在不同层级检测
- 显著提升小物体检测性能
3.1.3 Mask R-CNN
扩展Faster R-CNN:
- 增加分割分支
- 用RoI Align替代RoI Pooling(解决错位问题)
- 同时输出检测框和实例掩码
3.2 单阶段检测器
单阶段检测器直接预测物体类别和位置,速度更快但精度通常略低于两阶段方法。
3.2.1 YOLO系列
YOLO(You Only Look Once)核心思想:
- 将图像划分为S×S网格
- 每个网格预测B个边界框和置信度
- 同时预测类别概率
- 端到端训练
YOLOv3改进:
- 多尺度预测(类似FPN)
- 更好的主干网络(Darknet-53)
- 使用逻辑回归预测对象分数
# YOLOv3模型定义示例
class YOLOv3(nn.Module):def __init__(self, num_classes, anchors):super(YOLOv3, self).__init__()self.num_classes = num_classesself.anchors = anchors# 主干网络self.backbone = Darknet53()# 检测头self.detect_head = nn.Sequential(# 包含多个卷积层和上采样# 输出三个尺度的特征图)def forward(self, x):# 提取特征features = self.backbone(x)# 多尺度预测outputs = self.detect_head(features)return outputs
YOLOv4/v5创新:
- 大量训练技巧(Mosaic数据增强、CIoU损失等)
- 更高效的网络设计
- 自注意力机制引入
3.2.2 SSD(Single Shot MultiBox Detector)
关键特点:
- 在不同层级的特征图上预测
- 使用不同比例的默认框(Default Box)
- 平衡速度和精度
3.2.3 RetinaNet
解决类别不平衡问题:
- 提出Focal Loss
- 对难样本赋予更大权重
- 保持单阶段速度的同时达到两阶段精度
3.3 Anchor-free方法
摆脱预定义anchor的限制,直接预测关键点或中心点。
3.3.1 CornerNet
创新点:
- 检测物体左上和右下角点
- 使用角点配对匹配物体
- 引入角点池化层
3.3.2 CenterNet
改进思路:
- 检测物体中心点
- 回归物体大小
- 简化检测流程
3.3.3 FCOS(Fully Convolutional One-Stage)
全卷积方法:
- 逐像素预测
- 中心度(Centerness)评分
- 多层级预测
3.4 基于Transformer的检测器
3.4.1 DETR(Detection Transformer)
开创性工作:
- 使用Transformer编码器-解码器架构
- 二分图匹配损失
- 完全端到端,无需NMS后处理
from transformers import DetrForObjectDetectiondef create_detr_model(num_classes):model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50",num_labels=num_classes,ignore_mismatched_sizes=True)return model
3.4.2 Swin Transformer
层次化设计:
- 移动窗口自注意力
- 计算效率高
- 适合密集预测任务
3.4.3 Deformable DETR
改进DETR:
- 可变形注意力机制
- 更快收敛
- 更好处理小物体
3.5 目标检测的关键技术
3.5.1 损失函数
-
分类损失:
- 交叉熵损失
- Focal Loss(处理不平衡)
-
定位损失:
- Smooth L1损失
- IoU损失系列(GIoU、DIoU、CIoU)
-
匹配策略:
- 二分图匹配(匈牙利算法)
- Anchor匹配(IoU阈值)
3.5.2 后处理技术
-
非极大值抑制(NMS):
- 抑制冗余检测框
- 保留最高得分检测
-
Soft-NMS:
- 连续降低重叠框分数
- 减少误删
-
自适应NMS:
- 动态调整抑制阈值
3.5.3 数据增强
-
基础增强:
- 随机翻转、裁剪、颜色抖动
-
高级增强:
- Mosaic(YOLOv4)
- MixUp
- CutMix
-
领域特定增强:
- 针对小物体、遮挡等的增强策略
第四章 目标检测评估与优化
准确评估目标检测模型的性能并持续优化是实际应用中的关键环节。本章将详细介绍评估指标、优化策略以及常见问题的解决方案。
4.1 评估指标
4.1.1 准确率指标
-
精确率(Precision):
- 正确检测占所有检测的比例
- TP / (TP + FP)
-
召回率(Recall):
- 正确检测占所有真实目标的比例
- TP / (TP + FN)
-
平均精度(AP):
- 不同召回率下的精确率平均值
- PASCAL VOC:11点插值法
- COCO:101点插值法
-
mAP(mean Average Precision):
- 所有类别AP的平均值
- 主要综合评估指标
4.1.2 定位指标
-
IoU(Intersection over Union):
- 检测框与真实框的交并比
- 常用阈值:0.5、0.75
-
定位误差:
- 中心点距离
- 宽高比例差异
4.1.3 速度指标
-
FPS(Frames Per Second):
- 每秒处理的图像数量
- 实际部署关键指标
-
延迟(Latency):
- 单张图像处理时间
-
FLOPs(Floating Point Operations):
- 计算复杂度
- 反映理论计算量
4.1.4 COCO评估标准
MS COCO数据集提出的综合评估:
- AP@[.5:.95]:IoU从0.5到0.95的平均AP
- AP@.5:宽松评估(IoU=0.5)
- AP@.75:严格评估(IoU=0.75)
- APS、APM、AP^L:小、中、大物体的AP
- AR:平均召回率
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOevaldef evaluate_coco(dataset, model, threshold=0.05):results = []for image_id in dataset.img_ids:# 加载图像image_info = dataset.loadImgs(image_id)[0]image_path = f"{dataset.img_dir}/{image_info['file_name']}"# 运行检测detections = model.detect(image_path)# 转换为COCO格式for det in detections:results.append({'image_id': image_id,'category_id': det['category_id'],'bbox': [det['x'], det['y'], det['w'], det['h']],'score': det['score']})# 加载标注coco_true = dataset.cocococo_pred = coco_true.loadRes(results)# 运行评估coco_eval = COCOeval(coco_true, coco_pred, 'bbox')coco_eval.evaluate()coco_eval.accumulate()coco_eval.summarize()return coco_eval.stats
4.2 模型优化策略
4.2.1 轻量化设计
-
高效主干网络:
- MobileNet
- ShuffleNet
- EfficientNet
-
模型压缩技术:
- 量化(8位/4位)
- 剪枝(结构化/非结构化)
- 知识蒸馏
-
架构优化:
- 深度可分离卷积
- 通道注意力
- 神经架构搜索
4.2.2 训练优化
-
数据增强策略:
- 针对特定场景定制增强
- AutoAugment学习最优策略
-
损失函数设计:
- 改进定位损失(如EIoU)
- 类别平衡损失
-
训练技巧:
- 预热学习率
- 标签平滑
- 模型EMA
4.2.3 后处理优化
-
NMS改进:
- Cluster-NMS
- Matrix-NMS
- 自适应阈值NMS
-
结果融合:
- 多模型集成
- 测试时增强(TTA)
-
延迟优化:
- 流水线处理
- 模型分片
4.3 常见问题与解决方案
4.3.1 小物体检测
挑战:
- 小物体特征信息少
- 容易被背景干扰
- 在特征图上分辨率低
解决方案:
- 高分辨率特征图(如FPN)
- 特征融合(如PANet)
- 针对性数据增强(小物体复制)
- 专用检测头(更小的anchor)
4.3.2 类别不平衡
挑战:
- 某些类别样本极少
- 模型偏向多数类
- 难样本挖掘困难
解决方案:
- 重采样(过采样/欠采样)
- 类别加权损失
- Focal Loss
- 渐进式训练
4.3.3 遮挡处理
挑战:
- 物体部分不可见
- 特征不完整
- 容易误检或漏检
解决方案:
- 上下文信息利用
- 部分匹配策略
- 可见性预测分支
- 关系建模(如Transformer)
4.3.4 跨域适应
挑战:
- 训练和测试数据分布不同
- 领域偏移导致性能下降
- 目标域标注数据少
解决方案:
- 领域对抗训练
- 风格迁移
- 自训练(Self-training)
- 测试时适应
第五章 目标检测应用实践
目标检测技术已广泛应用于各个行业和场景。本章将介绍典型应用案例,并提供实践指导和代码示例,帮助读者将理论知识转化为实际解决方案。
5.1 典型应用场景
5.1.1 智能安防与监控
-
人脸检测与识别:
- 出入口控制
- 重点人员布控
- 人群密度分析
-
异常行为检测:
- 打架斗殴识别
- 跌倒检测
- 可疑物品遗留
-
交通监控:
- 违章检测
- 车牌识别
- 交通流量统计
5.1.2 自动驾驶
-
环境感知:
- 车辆、行人检测
- 交通标志识别
- 可行驶区域分割
-
多传感器融合:
- 摄像头+雷达+LiDAR
- 时空信息融合
-
实时决策支持:
- 碰撞预警
- 自动紧急制动
5.1.3 工业质检
-
缺陷检测:
- 表面划痕
- 装配完整性
- 异物检测
-
自动化分拣:
- 物品分类
- 质量分级
-
流程监控:
- 生产线异常检测
- 工人操作合规性检查
5.1.4 医疗影像分析
-
病灶检测:
- 肺结节检测
- 肿瘤定位
- 骨折识别
-
医疗辅助:
- 手术器械追踪
- 器官定位
- 细胞计数
-
诊断支持:
- 异常区域标记
- 量化分析
5.1.5 零售与电商
-
智能货架:
- 商品识别
- 缺货检测
- 价格标签核对
-
顾客行为分析:
- 动线追踪
- 停留热点分析
- 拿取行为识别
-
视觉搜索:
- 拍照购物
- 相似商品推荐
5.2 实践指导
5.2.1 数据准备与标注
-
数据收集原则:
- 多样性:不同视角、光照、背景
- 代表性:覆盖实际场景的各类情况
- 平衡性:类别分布尽量均衡
-
标注工具选择:
- LabelImg:简单易用的矩形标注
- CVAT:功能丰富的在线工具
- LabelMe:支持多边形标注
- 商业平台:Scale AI、Supervisely
-
标注规范制定:
- 明确标注边界(如包含/不包含哪些部分)
- 处理遮挡情况的规则
- 多级分类体系设计
# 使用LabelImg生成的XML转换为COCO格式示例
import xml.etree.ElementTree as ET
import jsondef convert_voc_to_coco(voc_annotations, output_file):coco = {"images": [],"annotations": [],"categories": []}# 添加类别categories = set()for ann in voc_annotations:tree = ET.parse(ann)for elem in tree.iterfind('object/name'):categories.add(elem.text)coco["categories"] = [{"id": i+1, "name": name} for i, name in enumerate(sorted(categories))]# 转换标注ann_id = 1for img_id, ann in enumerate(voc_annotations, 1):tree = ET.parse(ann)root = tree.getroot()# 添加图像信息size = root.find('size')image_info = {"id": img_id,"file_name": root.find('filename').text,"width": int(size.find('width').text),"height": int(size.find('height').text)}coco["images"].append(image_info)# 添加标注信息for obj in root.iter('object'):cat_name = obj.find('name').textcat_id = next(c['id'] for c in coco['categories'] if c['name'] == cat_name)bbox = obj.find('bndbox')xmin = float(bbox.find('xmin').text)ymin = float(bbox.find('ymin').text)xmax = float(bbox.find('xmax').text)ymax = float(bbox.find('ymax').text)width = xmax - xminheight = ymax - yminannotation = {"id": ann_id,"image_id": img_id,"category_id": cat_id,"bbox": [xmin, ymin, width, height],"area": width * height,"iscrowd": 0}coco["annotations"].append(annotation)ann_id += 1# 保存COCO格式with open(output_file, 'w') as f:json.dump(coco, f)
5.2.2 模型选择指南
根据应用需求选择合适的检测模型:
需求场景 | 推荐模型 | 理由 |
---|---|---|
高精度 | Faster R-CNN、Cascade R-CNN | 两阶段方法精度高 |
实时性 | YOLOv5、YOLOX、NanoDet | 优化过的单阶段方法 |
移动端 | MobileDet、YOLO-Lite | 轻量级设计 |
小物体 | FPN、PANet | 多尺度特征融合 |
遮挡场景 | RelationNet、DETR | 关系建模能力强 |
多类别 | RetinaNet、ATSS | 处理类别不平衡好 |
5.2.3 训练技巧
-
学习率策略:
- 线性预热
- 余弦退火
- 多阶段衰减
-
损失函数选择:
- 分类:Focal Loss(类别不平衡时)
- 回归:GIoU/SIoU(更好的框回归)
-
数据增强组合:
- 基础增强:翻转、旋转、裁剪
- 高级增强:Mosaic、MixUp
- 领域特定增强
-
正则化方法:
- DropBlock
- 标签平滑
- 权重衰减
5.2.4 部署优化
-
模型转换:
- PyTorch → ONNX → TensorRT
- TorchScript序列化
- 量化感知训练
-
推理加速:
- 半精度推理(FP16)
- 层融合优化
- 内存访问优化
-
边缘部署:
- 模型剪枝
- 知识蒸馏
- 专用加速芯片(NPU)
# 使用TensorRT加速YOLOv5推理示例
import torch
import tensorrt as trtdef export_to_onnx(model, sample_input, onnx_path):torch.onnx.export(model,sample_input,onnx_path,opset_version=11,input_names=['images'],output_names=['output'])def build_engine(onnx_path, engine_path):logger = trt.Logger(trt.Logger.INFO)builder = trt.Builder(logger)network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))parser = trt.OnnxParser(network, logger)with open(onnx_path, 'rb') as model:if not parser.parse(model.read()):for error in range(parser.num_errors):print(parser.get_error(error))config = builder.create_builder_config()config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)serialized_engine = builder.build_serialized_network(network, config)with open(engine_path, 'wb') as f:f.write(serialized_engine)
5.3 完整案例:交通标志检测
5.3.1 数据集准备
使用德国交通标志检测基准数据集(GTSDB):
- 900张图像
- 43类交通标志
- 标注格式:PASCAL VOC
5.3.2 模型训练
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from engine import train_one_epoch, evaluate
import utilsdef train_traffic_sign_detector(dataset_train, dataset_test):# 加载预训练模型backbone = torchvision.models.mobilenet_v2(pretrained=True).featuresbackbone.out_channels = 1280# 定义anchor生成器anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256),),aspect_ratios=((0.5, 1.0, 2.0),))# 定义RoI poolingroi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size=7,sampling_ratio=2)# 创建Faster R-CNN模型model = FasterRCNN(backbone,num_classes=43, # 43类交通标志rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler)# 数据加载器data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=4, shuffle=True,collate_fn=utils.collate_fn)data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=2, shuffle=False,collate_fn=utils.collate_fn)# 优化器params = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=0.005,momentum=0.9, weight_decay=0.0005)# 学习率调度器lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)# 训练循环num_epochs = 10for epoch in range(num_epochs):train_one_epoch(model, optimizer, data_loader_train, torch.device('cuda'), epoch, print_freq=10)lr_scheduler.step()evaluate(model, data_loader_test, device=torch.device('cuda'))return model
5.3.3 模型评估
def evaluate_model(model, data_loader):model.eval()results = []with torch.no_grad():for images, targets in data_loader:images = list(img.to('cuda') for img in images)outputs = model(images)for i, output in enumerate(outputs):boxes = output['boxes'].cpu().numpy()scores = output['scores'].cpu().numpy()labels = output['labels'].cpu().numpy()for box, score, label in zip(boxes, scores, labels):if score > 0.5: # 置信度阈值results.append({'image_id': targets[i]['image_id'].item(),'category_id': label.item(),'bbox': [box[0], box[1], box[2]-box[0], box[3]-box[1]],'score': score.item()})# 转换为COCO评估格式并计算mAPcoco_gt = data_loader.dataset.cocococo_dt = coco_gt.loadRes(results)coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')coco_eval.evaluate()coco_eval.accumulate()coco_eval.summarize()return coco_eval.stats
5.3.4 部署应用
import cv2
import numpy as npclass TrafficSignDetector:def __init__(self, model_path):self.model = torch.load(model_path)self.model.eval()self.class_names = [...] # 43类交通标志名称self.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])def detect(self, image_path, conf_thresh=0.5):# 读取图像image = cv2.imread(image_path)image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 预处理image_tensor = self.transform(image_rgb)image_tensor = image_tensor.unsqueeze(0).to('cuda')# 推理with torch.no_grad():outputs = self.model(image_tensor)# 后处理boxes = outputs[0]['boxes'].cpu().numpy()scores = outputs[0]['scores'].cpu().numpy()labels = outputs[0]['labels'].cpu().numpy()# 绘制结果for box, score, label in zip(boxes, scores, labels):if score > conf_thresh:x1, y1, x2, y2 = map(int, box)cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)label_text = f"{self.class_names[label]}: {score:.2f}"cv2.putText(image, label_text, (x1, y1-10),cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)return image