YOLO 11 图像分类推理 Web 服务
flyfish
import os
import io
import uuid
import base64
import time
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from flask import Flask, request, jsonify
import datetime

# ==================== Configuration - edit the values below ====================
MODEL_PATH = "/home/user/yolo/runs/classify/train/weights/best.pt"  # path of the model weights file
OUTPUT_BASE = "inference_results"  # base directory under which results are saved
IMGSZ = 320  # image size (default resize target used by preprocess_image)
MAX_PER_FOLDER = 5000  # maximum number of images kept in one folder
PORT = 5000  # port of the web service
HOST = "0.0.0.0"  # listen address; 0.0.0.0 allows access from every network interface
# ================================================================================

# Initialise the Flask application.
app = Flask(__name__)

# Module-level variable holding the model (loaded only once).
model = None

# Select a CJK-capable font for matplotlib so Chinese text can be displayed.
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['AR PL UMing CN']
plt.rcParams['axes.unicode_minus'] = False

# File-name suffixes that count as images when deciding whether a batch folder is full.
_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')


def load_model():
    """Load the YOLO model into the module-level ``model`` variable.

    Meant to be called once, when the service starts.

    Returns:
        bool: True when the model was created from ``MODEL_PATH``; False when
        loading raised an exception (the error is printed).
    """
    global model
    try:
        model = YOLO(MODEL_PATH)
        print(f"成功加载模型: {MODEL_PATH}")
        return True
    except Exception as e:
        print(f"模型加载失败: {e}")
        return False


def preprocess_image(image, imgsz=IMGSZ):
    """Turn a PIL image into a batched tensor for inference.

    Args:
        image: PIL image; it is converted to RGB first.
        imgsz: Target width and height, in pixels, of the resized image.

    Returns:
        tuple: ``(img_tensor, img)`` where ``img_tensor`` is the transformed
        image with a leading batch dimension added and ``img`` is the RGB
        version of the input image.
    """
    img = image.convert('RGB')

    # Normalising with mean 0 / std 1 leaves the values produced by ToTensor
    # unchanged. NOTE(review): the original author states these steps match
    # the training-time preprocessing - confirm against the training pipeline.
    transform = T.Compose([
        T.Resize((imgsz, imgsz)),
        T.ToTensor(),
        T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
    ])

    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension
    return img_tensor, img


def _count_image_files(directory):
    """Return how many regular files in ``directory`` have an image suffix."""
    with os.scandir(directory) as entries:
        return sum(
            1
            for entry in entries
            if entry.is_file() and entry.name.lower().endswith(_IMAGE_SUFFIXES)
        )


def create_output_directory(base_dir, class_id, max_per_folder=MAX_PER_FOLDER):
    """Return a batch directory for ``class_id`` that still has room.

    Directories are laid out as ``<base_dir>/class_<class_id>/batch_<n>``,
    starting at ``batch_0``. The first batch directory holding fewer than
    ``max_per_folder`` image files is created if needed and returned.

    Args:
        base_dir: Root directory of all results.
        class_id: Predicted class identifier used in the directory name.
        max_per_folder: Maximum number of image files per batch directory.

    Returns:
        str: Path of the batch directory to write into.
    """
    class_dir = os.path.join(base_dir, f"class_{class_id}")

    subdir_index = 0
    while True:
        current_dir = os.path.join(class_dir, f"batch_{subdir_index}")
        os.makedirs(current_dir, exist_ok=True)

        # Count each image file exactly once. The previous implementation
        # added the full directory listing once per extension pattern, which
        # made a folder look full at roughly a fifth of max_per_folder.
        if _count_image_files(current_dir) < max_per_folder:
            return current_dir
        subdir_index += 1


def _safe_filename_part(value):
    """Return ``value`` as text that can be embedded in a single file name.

    Every character that is not alphanumeric or one of ``. _ -`` is replaced
    by ``_``, so path separators supplied by a client cannot alter the
    directory a result is written to.
    """
    return "".join(
        ch if ch.isalnum() or ch in "._-" else "_"
        for ch in str(value)
    )


def save_result_image(original_img, result, image_id):
    """Save a copy of the image annotated with the prediction.

    The file name is ``<timestamp>_<image id>_<8 uuid characters>.jpg`` and
    the file is placed in the batch directory chosen by
    ``create_output_directory`` for the predicted class.

    Args:
        original_img: PIL image to annotate; it is copied, not modified.
        result: Inference result exposing ``probs.top1`` and ``probs.top1conf``.
        image_id: Client-supplied identifier; sanitised before it is used in
            the file name.

    Returns:
        str: Path of the saved image.
    """
    draw_img = original_img.copy()
    draw = ImageDraw.Draw(draw_img)

    class_id = result.probs.top1
    confidence = result.probs.top1conf.item()

    # Text shown on the image (class name intentionally not included).
    text = f"(ID: {class_id}): {confidence:.4f}"

    # Prefer the WenQuanYi font; fall back to PIL's default font when it is unavailable.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", 16, index=0)
    except Exception as e:
        print(f"文泉驿字体加载失败: {e},将使用默认字体")
        font = ImageFont.load_default()

    # White box behind the label, then the label itself in red.
    text_bbox = draw.textbbox((10, 10), text, font=font)
    draw.rectangle(
        [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
        fill="white",
    )
    draw.text((10, 10), text, font=font, fill=(255, 0, 0))

    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_str = str(uuid.uuid4())[:8]  # first 8 characters of a UUID keep names distinct
    # image_id comes straight from the HTTP request, so it is sanitised before
    # becoming part of a file-system path.
    file_name = f"{current_time}_{_safe_filename_part(image_id)}_{unique_str}.jpg"

    output_dir = create_output_directory(OUTPUT_BASE, class_id)
    output_path = os.path.join(output_dir, file_name)
    draw_img.save(output_path)
    return output_path


def _append_to_result_log(text):
    """Append ``text`` to ``inference_results.txt`` inside ``OUTPUT_BASE``."""
    result_file = os.path.join(OUTPUT_BASE, "inference_results.txt")
    with open(result_file, 'a', encoding='utf-8') as f:
        f.write(text)


def process_image(image, image_id):
    """Run inference on one image, save the annotated result and log it.

    Args:
        image: PIL image to classify.
        image_id: Client-supplied identifier echoed in the response and log.

    Returns:
        dict: On success ``success`` (True), ``image_id``, ``current_time``,
        ``class_id`` and ``confidence``. On failure ``success`` (False),
        ``image_id`` and ``error``; the error is also printed and appended to
        the result log.
    """
    try:
        img_tensor, original_img = preprocess_image(image)

        # NOTE(review): the service runs with threaded=True, so this shared
        # model object may be called from several threads at once - confirm
        # that is safe for the installed Ultralytics version.
        results = model(img_tensor)

        result = results[0]
        class_id = result.probs.top1  # most likely class id
        confidence = result.probs.top1conf.item()  # confidence of that class

        output_path = save_result_image(original_img, result, image_id)

        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        result_str = f"[{current_time}] 图像ID: {image_id}\n"
        result_str += f" 预测ID: {class_id}\n"
        result_str += f" 置信度: {confidence:.4f}\n"
        result_str += f" 保存路径: {output_path}\n"
        result_str += "-"*50 + "\n"
        _append_to_result_log(result_str)

        return {
            "success": True,
            "image_id": image_id,
            "current_time": current_time,
            "class_id": int(class_id),
            "confidence": float(confidence),
        }
    except Exception as e:
        error_msg = f"处理图像ID {image_id} 时出错: {str(e)}\n"
        print(error_msg)
        _append_to_result_log(
            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {error_msg}"
        )
        return {
            "success": False,
            "image_id": image_id,
            "error": str(e),
        }


@app.route('/infer', methods=['POST'])
def infer():
    """API endpoint: classify a base64-encoded image.

    Expects a JSON object with the fields ``image_id`` and ``image_base64``.

    Returns:
        The JSON result of ``process_image``; a 400 response when the body is
        not a JSON object containing both fields; a 500 response when the
        model is not loaded or the image cannot be decoded or opened.
    """
    if model is None:
        return jsonify({
            "success": False,
            "error": "模型未加载,请检查服务状态"
        }), 500

    # silent=True returns None for a missing, non-JSON or malformed body, so
    # those requests get the JSON 400 payload below instead of a framework
    # error page.
    data = request.get_json(silent=True)

    # The isinstance check also rejects valid JSON that is not an object
    # (for example a bare number), on which the ``in`` tests would fail.
    if not isinstance(data, dict) or 'image_id' not in data or 'image_base64' not in data:
        return jsonify({
            "success": False,
            "error": "请求数据缺少image_id或image_base64字段"
        }), 400

    try:
        image_data = base64.b64decode(data['image_base64'])
        image = Image.open(io.BytesIO(image_data))

        result = process_image(image, data['image_id'])
        return jsonify(result)
    except Exception as e:
        return jsonify({
            "success": False,
            "image_id": data.get('image_id'),
            "error": f"处理图像时出错: {str(e)}"
        }), 500


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint.

    Returns:
        A JSON status with a timestamp: "healthy" (HTTP 200) when the model is
        loaded, otherwise "unhealthy" with HTTP 500.
    """
    if model is not None:
        return jsonify({
            "status": "healthy",
            "model_loaded": True,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        })
    else:
        return jsonify({
            "status": "unhealthy",
            "model_loaded": False,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }), 500


if __name__ == "__main__":
    # Create the base output directory before anything is written to it.
    os.makedirs(OUTPUT_BASE, exist_ok=True)

    model_loaded = load_model()

    # Start the web service only when the model is available.
    if model_loaded:
        print(f"服务启动,监听 {HOST}:{PORT}")
        # threaded=True lets the server handle requests in multiple threads.
        app.run(host=HOST, port=PORT, threaded=True)
    else:
        print("模型加载失败,无法启动服务")
使用方式
- 启动服务:
python yolo11_web_service.py
- 客户端发送请求示例(使用Python):
import requests
import base64

# Read the image file and encode its bytes as a base64 string.
with open("test.jpg", "rb") as image_file:
    raw_bytes = image_file.read()
image_base64 = base64.b64encode(raw_bytes).decode('utf-8')

# Build the JSON payload sent to the /infer endpoint.
payload = {
    "image_id": "test_001",
    "image_base64": image_base64,
}

# Send the request and print the decoded JSON response.
response = requests.post("http://localhost:5000/infer", json=payload)
print(response.json())