基于 FastAPI 的 YOLO 批量目标检测 API:支持单图 / 文件夹自适应处理
项目背景
在计算机视觉任务中,我们经常需要对大量图片进行目标检测和标注。YOLO 系列模型凭借其高效性成为目标检测的首选工具之一,但批量处理图片时往往需要编写繁琐的脚本。本文将介绍一个基于 FastAPI 和 YOLO11 的 API 服务,支持单张图片和文件夹批量处理,可自定义置信度、交并比等参数,并能返回详细的标注统计结果。
功能特点
- 支持单张图片和文件夹批量处理,自动识别输入类型
- 可自定义置信度阈值 (conf) 和交并比阈值 (iou)
- 自动选择运行设备 (GPU 优先,无则 CPU)
- 生成标注后的图片和检测结果 TXT 文件
- 返回详细的标注统计信息 (每个文件的目标类别及数量)
- 提供完整的任务状态查询和结果下载功能
技术栈
- fastapi:用于构建高性能的 Web API。
- uvicorn:一个快速的 ASGI 服务器,用于运行 FastAPI 应用。
- pydantic:用于数据校验,以及基于类型提示定义请求模型。
- ultralytics:提供 YOLO 模型,用于目标检测。
- opencv-python:用于图像处理和计算机视觉任务。
- numpy:用于数值计算。
- pillow:Python 图像处理库(PIL 的维护分支),用于图像处理。
- torch:PyTorch 深度学习框架,YOLO 模型依赖于此。
- base64:用于 Base64 编码和解码;它是 Python 标准库,无需安装,这里仅为完整性列出。
代码实现
完整代码
import base64
import json
import logging
import os
import shutil
import threading
import time
import uuid
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel
from ultralytics import YOLO

# 配置日志,设置日志级别为INFO,方便后续调试和监控程序运行状态
# Root logging at INFO so request and task progress show up on the console;
# the module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application; title/version are metadata for the API.
app = FastAPI(title="YOLO目标检测API", version="1.0")

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory task registry: task_id -> status / progress / result dict.
tasks: Dict[str, Dict[str, Any]] = {}

# Palette for boxes and label backgrounds; a class id picks
# COLORS[class_id % len(COLORS)].
COLORS = [
    (0, 255, 0),
    (0, 0, 255),
    (255, 0, 0),
    (255, 255, 0),
    (0, 255, 255),
    (255, 0, 255),
    (128, 0, 0),
    (0, 128, 0),
    (0, 0, 128),
]
# Label font used by draw_annotations: Arial at size 18 if it can be loaded,
# otherwise Pillow's built-in default font (e.g. on machines without arial.ttf).
try:
    font = ImageFont.truetype("arial.ttf", 18)
except (OSError, ImportError):
    # OSError: font file missing/unreadable; ImportError: Pillow built without FreeType.
    font = ImageFont.load_default()
class DetectRequest(BaseModel):
    """Body of ``POST /detect``; ``conf`` and ``iou`` are range-checked by the route."""

    input_path: str  # an image file, or a folder of images
    output_dir: str = "demo"  # folder for annotated images and detection TXT files
    model_path: str = "yolo11n.pt"  # weights passed to ultralytics.YOLO
    device: Optional[str] = None  # e.g. "cpu" or "0"; None selects automatically
    conf: float = 0.25  # confidence threshold
    iou: float = 0.7  # IoU threshold forwarded to the model
    target_classes: Optional[str] = None  # comma-separated class names to keep


# Lower-case extensions that count as images when input_path is a folder.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff")


def _text_size(draw, label: str):
    """Return the (width, height) of ``label`` rendered with the module-level ``font``."""
    try:
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    except AttributeError:
        # Older Pillow releases have no ImageDraw.textbbox; fall back to textsize.
        return draw.textsize(label, font=font)
    return right - left, bottom - top


def draw_annotations(image: np.ndarray, results) -> np.ndarray:
    """Draw detection boxes and ``"<class>: <conf>"`` labels onto an image.

    :param image: BGR image array (as produced by ``cv2.imdecode``).
    :param results: ultralytics prediction results; only ``results[0]`` is drawn.
    :return: a new annotated BGR array, or ``image`` itself when there is
        nothing to draw (no boxes, or a result without a ``boxes`` attribute value).
    """
    result = results[0]
    boxes = result.boxes
    if boxes is None or len(boxes) == 0:
        return image

    # PIL draws in RGB while OpenCV arrays are BGR: convert in, draw, convert back.
    frame_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(frame_pil)
    class_names = result.names
    xyxy = boxes.xyxy.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)
    scores = boxes.conf.cpu().numpy()

    for box, class_id, score in zip(xyxy, class_ids, scores):
        x1, y1, x2, y2 = map(int, box)
        color = COLORS[class_id % len(COLORS)]
        draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

        label = f"{class_names[class_id]}: {score:.2f}"
        text_width, text_height = _text_size(draw, label)
        # The label sits above the box; if that would cross the top edge, move it inside.
        label_y1 = y1 - text_height - 5
        if label_y1 <= 0:
            label_y1 = y1 + 5
        draw.rectangle(
            [(x1, label_y1), (x1 + text_width, label_y1 + text_height)], fill=color
        )
        draw.text((x1, label_y1), label, font=font, fill=(255, 255, 255))

    return cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)


def _fail(task_id: str, message: str) -> None:
    """Record ``task_id`` as failed, keeping whatever progress it had reached."""
    progress = tasks.get(task_id, {}).get("progress", 0)
    tasks[task_id] = {"status": "failed", "progress": progress, "message": message}


def _resolve_device(device: Optional[str]) -> str:
    """Return ``device``; when it is None, "0" (first GPU) if CUDA is available, else "cpu"."""
    if device is not None:
        return device
    return "0" if torch.cuda.is_available() else "cpu"


def _parse_target_classes(target_classes: Optional[str]) -> Optional[set]:
    """Turn ``"car, person"`` into ``{"car", "person"}``; None when no name is given.

    Blank entries (e.g. from a trailing comma) are ignored instead of being
    reported as an invalid class.
    """
    if not target_classes:
        return None
    names = {name.strip() for name in target_classes.split(",") if name.strip()}
    return names or None


def _list_images(input_path: str):
    """Resolve ``input_path`` to ``(input_dir, file_names, is_single_file)``.

    A file is used as-is; a folder yields its image files (matched by extension,
    sorted by name). Returns None when the path is neither a file nor a folder.
    """
    if os.path.isfile(input_path):
        return os.path.dirname(input_path), [os.path.basename(input_path)], True
    if os.path.isdir(input_path):
        file_names = sorted(
            name
            for name in os.listdir(input_path)
            if name.lower().endswith(_IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(input_path, name))
        )
        return input_path, file_names, False
    return None


def _keep_classes(result, class_names: Dict[int, str], target_set: set):
    """Return ``result`` restricted to detections whose class name is in ``target_set``.

    Indexing an ultralytics ``Results`` with a boolean mask applies the mask to
    its boxes (and any masks/keypoints), so ``plot()`` still works afterwards.
    """
    boxes = result.boxes
    if boxes is None:
        return result
    keep = [class_names[int(cls_id)] in target_set for cls_id in boxes.cls.tolist()]
    # Build the mask on the same device as the box data (may be a GPU).
    mask = torch.tensor(keep, dtype=torch.bool, device=boxes.data.device)
    return result[mask]


def _process_image(
    model,
    img_path: str,
    img_name: str,
    output_dir: str,
    device: str,
    conf: float,
    iou: float,
    target_set: Optional[set],
) -> Dict[str, Any]:
    """Detect objects in one image and write its annotated JPG and detection TXT.

    Outputs are ``<img_name>_annotated.jpg`` and ``<img_name>_detections.txt`` in
    ``output_dir``; each TXT line is ``<class> <conf> <x1> <y1> <x2> <y2>``
    (xyxy corners, rounded).

    :return: the per-file summary stored under ``annotations`` in the task result.
    """
    result = model(img_path, device=device, conf=conf, iou=iou)[0]
    if target_set:
        result = _keep_classes(result, model.names, target_set)

    output_img_path = os.path.join(output_dir, f"{img_name}_annotated.jpg")
    txt_path = os.path.join(output_dir, f"{img_name}_detections.txt")

    # result.plot() returns a BGR array; PIL expects RGB.
    annotated_rgb = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
    Image.fromarray(annotated_rgb).save(output_img_path)

    class_counts: Dict[str, int] = {}
    lines: List[str] = []
    for box in result.boxes if result.boxes is not None else []:
        cls_name = model.names[int(box.cls)]
        confidence = round(float(box.conf), 4)
        x1, y1, x2, y2 = map(round, box.xyxy[0].tolist())
        lines.append(f"{cls_name} {confidence} {x1} {y1} {x2} {y2}\n")
        class_counts[cls_name] = class_counts.get(cls_name, 0) + 1
    with open(txt_path, "w", encoding="utf-8") as fh:
        fh.writelines(lines)

    return {
        "annotated_image": output_img_path,
        "detection_txt": txt_path,
        "class_counts": class_counts,
    }


def batch_detect_and_annotate(
    task_id: str,
    input_path: str,
    output_dir: str = "demo",
    model_path: str = "yolo11n.pt",
    device: Optional[str] = None,
    conf: float = 0.25,
    iou: float = 0.7,
    target_classes: Optional[str] = None,
):
    """Detect objects in one image or in every image of a folder.

    Status lives in ``tasks[task_id]``: "running" with progress/message while
    working, then "completed" with counts, timing, parameters and per-file
    summaries, or "failed" with a message. Errors are recorded, never raised.
    A single image that fails is logged and counted in ``fail_count`` without
    stopping the batch.

    :param task_id: key under which the status is stored in ``tasks``.
    :param input_path: an image file or a folder of images.
    :param output_dir: folder for the outputs (created if missing).
    :param model_path: weights passed to ``YOLO``.
    :param device: inference device; None selects GPU "0" or "cpu".
    :param conf: confidence threshold.
    :param iou: IoU threshold forwarded to the model.
    :param target_classes: comma-separated class names to keep; None keeps all.
    """
    tasks[task_id] = {"status": "running", "progress": 0, "message": "开始处理..."}
    try:
        start_time = time.time()
        os.makedirs(output_dir, exist_ok=True)
        selected_device = _resolve_device(device)

        try:
            model = YOLO(model_path)
        except Exception as exc:
            _fail(task_id, f"模型加载失败:{str(exc)}")
            return

        target_set = _parse_target_classes(target_classes)
        if target_set:
            invalid_classes = sorted(target_set - set(model.names.values()))
            if invalid_classes:
                _fail(task_id, f"无效的目标类别: {', '.join(invalid_classes)}")
                return

        listing = _list_images(input_path)
        if listing is None:
            _fail(task_id, f"错误:输入路径不存在 - {input_path}")
            return
        input_dir, image_files, is_single_file = listing
        if not image_files:
            _fail(task_id, f"错误:未找到图片文件 - {input_path}")
            return

        total_files = len(image_files)
        success_count = 0
        fail_count = 0
        # NOTE: outputs and summary keys use the file stem, so files that differ
        # only by extension (a.jpg / a.png) overwrite each other.
        file_annotations: Dict[str, Dict[str, Any]] = {}
        for index, img_file in enumerate(image_files):
            # Progress counts finished files, so it only reaches 100 on completion.
            tasks[task_id]["progress"] = int(index / total_files * 100)
            tasks[task_id]["message"] = f"正在处理:{img_file}"
            img_name = os.path.splitext(img_file)[0]
            try:
                file_annotations[img_name] = _process_image(
                    model,
                    os.path.join(input_dir, img_file),
                    img_name,
                    output_dir,
                    selected_device,
                    conf,
                    iou,
                    target_set,
                )
                success_count += 1
            except Exception as exc:
                fail_count += 1
                logger.exception("处理%s失败: %s", img_file, exc)

        tasks[task_id] = {
            "status": "completed",
            "progress": 100,
            "total_time": round(time.time() - start_time, 2),
            "success_count": success_count,
            "fail_count": fail_count,
            "total_files": total_files,
            "output_dir": os.path.abspath(output_dir),
            "input_path": input_path,
            "is_single_file": is_single_file,
            "parameters": {
                "confidence_threshold": conf,
                "iou_threshold": iou,
                "device": selected_device,
                "target_classes": sorted(target_set) if target_set else None,
            },
            "annotations": file_annotations,
            "message": "处理完成",
        }
    except Exception as exc:
        logger.exception("任务 %s 出现未知错误", task_id)
        _fail(task_id, f"未知错误:{str(exc)}")


@app.post("/detect")
async def detect(request: DetectRequest):
    """Run a detection job to completion and return its full result.

    The blocking inference runs in the threadpool so the event loop keeps
    serving /status, /download and WebSocket traffic in the meantime.

    :param request: detection parameters.
    :return: the completed task dict plus ``task_id``; a 400 JSON response with
        the failure message when the task failed.
    :raises HTTPException: 400 for out-of-range conf/iou, 500 for unexpected errors.
    """
    logger.info("收到检测请求: %s", request.input_path)
    if not (0 <= request.conf <= 1):
        raise HTTPException(status_code=400, detail="conf参数必须在0-1之间")
    if not (0 <= request.iou <= 1):
        raise HTTPException(status_code=400, detail="iou参数必须在0-1之间")

    # Random IDs: timestamp-based IDs collide when two requests arrive in the same ms.
    task_id = uuid.uuid4().hex
    logger.info("创建任务: %s", task_id)

    # NOTE(review): input_path, output_dir and model_path come straight from the
    # client, so callers can read/write arbitrary server paths and choose which
    # weights file gets loaded; only expose this API on a trusted network.
    try:
        await run_in_threadpool(
            batch_detect_and_annotate,
            task_id,
            request.input_path,
            request.output_dir,
            request.model_path,
            request.device,
            request.conf,
            request.iou,
            request.target_classes,
        )
    except Exception as exc:
        logger.exception("请求处理失败: %s", exc)
        raise HTTPException(status_code=500, detail=f"请求处理失败: {str(exc)}") from exc

    task_result = tasks.get(task_id)
    if not task_result:
        logger.error("任务执行失败,未获取到结果: %s", task_id)
        raise HTTPException(status_code=500, detail="任务执行失败,未获取到结果")
    if task_result["status"] == "failed":
        logger.error("任务失败: %s, 原因: %s", task_id, task_result["message"])
        return JSONResponse(
            status_code=400,
            content={
                "task_id": task_id,
                "status": "failed",
                "message": task_result["message"],
            },
        )

    logger.info("任务完成: %s, 处理时间: %s秒", task_id, task_result["total_time"])
    # task_id lets the client call /status, /results and /download afterwards.
    return {"task_id": task_id, **task_result}


@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the stored state of ``task_id``; unknown IDs get a ``not_found`` payload (HTTP 200)."""
    return tasks.get(task_id, {"status": "not_found", "message": "任务ID不存在"})


@app.get("/results/{task_id}")
async def get_results(task_id: str):
    """Return the result summary of a completed task, or its state if unfinished.

    :raises HTTPException: 404 when the task ID is unknown.
    """
    task = tasks.get(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务ID不存在")
    if task["status"] != "completed":
        # .get: failure records are not guaranteed to carry every key.
        return {
            "status": task["status"],
            "progress": task.get("progress", 0),
            "message": task.get("message", ""),
            "error": "任务未完成,无法获取结果",
        }
    return {
        "task_id": task_id,
        "status": "completed",
        "total_time": task["total_time"],
        "success_count": task["success_count"],
        "fail_count": task["fail_count"],
        "total_files": task["total_files"],
        "input_path": task["input_path"],
        "is_single_file": task["is_single_file"],
        "output_dir": task["output_dir"],
        "parameters": task["parameters"],
        "annotations": task["annotations"],
        "message": "处理完成",
    }


def _is_inside(directory: str, path: str) -> bool:
    """True when the resolved ``path`` is ``directory`` or lies underneath it."""
    directory = os.path.normcase(directory)
    path = os.path.normcase(path)
    try:
        return os.path.commonpath([directory, path]) == directory
    except ValueError:
        # e.g. paths on different drives on Windows.
        return False


@app.get("/download/{task_id}/{filename:path}")
async def download_file(task_id: str, filename: str):
    """Send a file from a completed task's output folder as an attachment.

    :raises HTTPException: 400 if the task is unknown or unfinished; 404 if the
        file does not exist or resolves outside the task's output folder.
    """
    task = tasks.get(task_id)
    if not task or task["status"] != "completed":
        raise HTTPException(status_code=400, detail="任务未完成或不存在")

    output_dir = os.path.realpath(task["output_dir"])
    file_path = os.path.realpath(os.path.join(output_dir, filename))
    # Block "../" segments and absolute filenames: only files under output_dir are served.
    if not _is_inside(output_dir, file_path) or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="文件不存在")
    return FileResponse(
        path=file_path,
        filename=os.path.basename(filename),
        media_type="application/octet-stream",
    )


def _predict_and_annotate(model, image: np.ndarray, conf: float, iou: float) -> np.ndarray:
    """Run ``model`` on one BGR frame and return the annotated BGR frame."""
    results = model.predict(source=image, conf=conf, iou=iou, verbose=False)
    return draw_annotations(image, results)


@app.websocket("/ws/video_detection")
async def detect_video_websocket(websocket: WebSocket):
    """Annotate video frames streamed over a WebSocket.

    Each text message is JSON with ``model_name``, ``image_base64`` (a data URL
    or bare base64) and optional ``conf`` / ``iou`` percentages (defaults 60 / 65).
    Each reply is ``{"image_base64": "data:image/jpeg;base64,..."}``. Messages
    that cannot be parsed or decoded are logged and skipped; any other error
    closes the socket with code 1011.

    :param websocket: the client connection.
    """
    await websocket.accept()
    logger.info("WebSocket 连接已建立。")
    # Models loaded on this connection, keyed by name, so weights load once, not per frame.
    models: Dict[str, Any] = {}
    try:
        while True:
            data_str = await websocket.receive_text()
            try:
                data = json.loads(data_str)
                model_name = data["model_name"]
                base64_str = data["image_base64"]
                conf = float(data.get("conf", 60)) / 100.0
                iou = float(data.get("iou", 65)) / 100.0
            except (ValueError, TypeError, KeyError, AttributeError) as exc:
                logger.warning("无法解析消息: %s,已跳过。", exc)
                continue

            model = models.get(model_name)
            if model is None:
                # NOTE(review): model_name is client-controlled and passed to YOLO();
                # loading untrusted weight files is unsafe — consider an allow-list.
                model = models[model_name] = YOLO(model_name)

            # Accept "data:image/...;base64,<payload>" as well as a bare payload.
            encoded_data = base64_str.split(",", 1)[-1] if isinstance(base64_str, str) else ""
            if not encoded_data:
                logger.warning("接收到不完整的Base64数据(数据部分为空),已跳过。")
                continue
            try:
                image_bytes = base64.b64decode(encoded_data)
            except (ValueError, TypeError) as exc:
                logger.warning("无法解析Base64字符串: %s,已跳过。", exc)
                continue
            if not image_bytes:
                logger.warning("Base64解码后数据为空,已跳过。")
                continue

            image_cv2 = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
            if image_cv2 is None:
                logger.warning("cv2.imdecode未能解码图像,可能数据已损坏,已跳过。")
                continue

            # Inference is blocking; run it off the event loop.
            annotated_image = await run_in_threadpool(
                _predict_and_annotate, model, image_cv2, conf, iou
            )
            ok, buffer = cv2.imencode(".jpg", annotated_image)
            if not ok:
                logger.warning("cv2.imencode 编码失败,已跳过。")
                continue
            result_base64 = base64.b64encode(buffer).decode("utf-8")
            await websocket.send_json({"image_base64": f"data:image/jpeg;base64,{result_base64}"})
    except WebSocketDisconnect:
        logger.info("WebSocket 客户端断开连接。")
    except Exception as exc:
        error_message = f"WebSocket 处理错误: {type(exc).__name__}"
        logger.exception("%s - %s", error_message, exc)
        try:
            await websocket.close(code=1011, reason=error_message)
        except Exception:
            # The connection may already be gone; nothing else to clean up.
            logger.debug("关闭 WebSocket 失败", exc_info=True)


if __name__ == '__main__':
    import uvicorn

    print("启动YOLO目标检测API服务...")
    print("支持的API端点:")
    print(" POST /detect - 启动检测任务")
    print(" GET /status/<task_id> - 获取任务状态")
    print(" GET /results/<task_id> - 获取结果")
    print(" GET /download/<task_id>/<filename> - 下载结果文件")
    print(" WS /ws/video_detection - 实时视频帧检测")
    uvicorn.run(app, host="0.0.0.0", port=5000)
核心参数说明
| 参数名 | 类型 | 说明 | 默认值 |
| --- | --- | --- | --- |
| input_path | String | 输入路径(支持单张图片或文件夹) | 无(必填) |
| output_dir | String | 结果输出文件夹路径 | "demo" |
| model_path | String | YOLO 模型路径 | "yolo11n.pt" |
| device | String | 运行设备("cpu" 或 "0") | 自动选择 |
| conf | Float | 置信度阈值(0-1) | 0.25 |
| iou | Float | NMS 交并比阈值(0-1) | 0.7 |
| target_classes | String | 只保留的目标类别,多个类别用英文逗号分隔(如 "car,person") | null(检测全部类别) |
部署与使用
1. 安装依赖
pip install fastapi uvicorn pydantic ultralytics opencv-python numpy pillow torch
(base64 是 Python 标准库,无需也不应通过 pip 安装。)
2. 启动服务
python yolo_api.py
服务启动后会监听 0.0.0.0:5000(即本机所有网卡的 5000 端口),输出如下:
启动YOLO目标检测API服务...
支持的API端点:
 POST /detect - 启动检测任务
 GET /status/<task_id> - 获取任务状态
 GET /results/<task_id> - 获取结果
 GET /download/<task_id>/<filename> - 下载结果文件
 WS /ws/video_detection - 实时视频帧检测
INFO:     Uvicorn running on http://0.0.0.0:5000 (Press CTRL+C to quit)
其中 POST /detect 为同步接口,处理完成后直接返回完整结果;GET /results/<task_id> 返回与 /detect 相同的结果信息。
3. 使用 Postman 调用 API
处理图片
- 请求 URL: http://localhost:5000/detect
- 请求方法: POST
- 请求体(`//` 之后的内容仅为说明;标准 JSON 不支持注释,实际发送前请删除):
{
"input_path": "C:/Users/HUAWEI/Desktop/yolo/demo2/tupian/", // 待检测图片文件夹
"output_dir": "C:/Users/HUAWEI/Desktop/yolo/demo2/output", // 结果输出文件夹
"model_path": "yolo11n.pt", // 模型路径(默认会自动下载)
"device": null, // 自动选择设备(也可指定"cpu"或"0")
"conf": 0.1, // 置信度阈值(越高越严格,默认0.25)
"iou": 0.6, // NMS 交并比阈值(越高保留的重叠框越多,默认0.7)
"target_classes" : "car" // 只保留的类别,多个用英文逗号分隔;省略则检测全部类别
}