当前位置: 首页 > news >正文

【map计算】自定义map计算

任务说明

我现在需要评估模型的预测效果,选择使用mAP(mean Average Precision)作为评估指标。
输入包含两个文件夹:一个是标准答案,里面是真值标注的xml文件;
另一个是你的答案,里面是你预测得到的xml文件。

脚本会逐个object地进行对比。你可以把标准答案复制一份当作你的答案进行测试,这样得分会是1.0000(前提是目标类别列表中的每个类别都在标准答案中出现,否则该类别的AP按0计入平均);当你对答案做了一些修改后,得分会下降。

还可以指定针对哪些类别计算map。

我来帮你创建一个计算mAP的脚本,用于遍历文件夹a和b,并计算目标检测的综合得分。我创建了一个完整的mAP计算脚本,主要功能包括:

核心功能

  1. 文件夹遍历: 自动遍历文件夹a(预测结果)和文件夹b(真值标注)
  2. 多IoU阈值计算: 支持mAP@0.5、mAP@0.75、mAP@0.9的计算
  3. 综合得分: 按照你指定的权重公式计算:0.6*mAP50 + 0.2*mAP75 + 0.2*mAP90

关键特性

  • IoU计算: 实现了标准的边界框IoU计算
  • AP计算: 使用11点插值法计算Average Precision
  • 格式支持: 支持PASCAL VOC风格的XML标注文件(预测结果可在object下附带confidence字段)
  • 类别筛选: 可以指定目标类别列表进行计算

使用方法

  1. 修改配置:

    folder_a = "你的预测结果文件夹路径"
    folder_b = "你的真值标注文件夹路径"  
    target_classes = ["person", "car"]  # 修改为你的目标类别名称(与xml中object的name一致)
    
  2. 文件格式要求:

    • 支持PASCAL VOC风格的XML格式
    • 预测结果需要包含置信度分数
    • 边界框格式:[x1, y1, x2, y2]
  3. 运行脚本:

    python map_calculator.py
    

代码

import os
import xml.etree.ElementTree as ET
import numpy as np
from collections import defaultdict
import glob
import json


def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: Sequence ``[xmin, ymin, xmax, ymax]``.
        box2: Sequence ``[xmin, ymin, xmax, ymax]``.

    Returns:
        IoU as a float in ``[0.0, 1.0]``; ``0.0`` when the boxes do not
        overlap or when the union area is not positive.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    # Touching edges or disjoint boxes have no overlapping area.
    if x2 <= x1 or y2 <= y1:
        return 0.0

    intersection = (x2 - x1) * (y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection
    return intersection / union if union > 0 else 0.0


def parse_xml_annotation(xml_path):
    """Parse one PASCAL VOC style XML annotation file.

    Objects without a ``name`` or a ``bndbox`` element are skipped. The
    optional ``confidence`` element is read for predictions; when it is
    absent (ground truth) the confidence defaults to ``1.0``.

    Args:
        xml_path: Path of the XML file.

    Returns:
        ``{'filename': str, 'objects': [{'name': str, 'bbox': [xmin, ymin,
        xmax, ymax], 'confidence': float}, ...]}``, or ``None`` when the
        file cannot be read or contains malformed values (a message is
        printed in that case).
    """
    try:
        root = ET.parse(xml_path).getroot()

        # Fall back to the XML file's own stem when <filename> is missing
        # or empty.
        filename_elem = root.find('filename')
        if filename_elem is not None and filename_elem.text:
            filename = filename_elem.text
        else:
            filename = os.path.splitext(os.path.basename(xml_path))[0]

        objects = []
        for obj in root.findall('object'):
            name_elem = obj.find('name')
            bndbox = obj.find('bndbox')
            if name_elem is None or bndbox is None:
                continue

            # Surrounding whitespace would make the class-name comparison
            # fail silently, so normalise it here.
            name = (name_elem.text or '').strip()
            if not name:
                continue

            bbox = [
                float(bndbox.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]

            confidence_elem = obj.find('confidence')
            if confidence_elem is not None:
                confidence = float(confidence_elem.text)
            else:
                confidence = 1.0

            objects.append(
                {'name': name, 'bbox': bbox, 'confidence': confidence}
            )

        return {'filename': filename, 'objects': objects}
    except (ET.ParseError, OSError, ValueError, TypeError,
            AttributeError) as e:
        # A missing coordinate tag surfaces as AttributeError, empty text
        # as TypeError, and non-numeric text as ValueError.
        print(f"解析XML文件失败 {xml_path}: {e}")
        return None


def calculate_ap_at_iou(predictions, ground_truths, iou_threshold=0.5):
    """Compute 11-point interpolated Average Precision at one IoU threshold.

    Predictions are visited in descending confidence order. Each one is
    matched to the not-yet-matched ground truth with the highest IoU; it
    is a true positive when that IoU reaches ``iou_threshold``, otherwise
    a false positive.

    Objects may carry an optional ``'image_id'`` key. A prediction is only
    compared with ground truths that have the same ``image_id``, so boxes
    from different images never match each other. Objects without the key
    all share the id ``None`` and are therefore pooled together.

    Args:
        predictions: List of dicts with ``'bbox'`` and ``'confidence'``.
        ground_truths: List of dicts with ``'bbox'``.
        iou_threshold: Minimum IoU for a match.

    Returns:
        The AP as a float; ``0.0`` when either list is empty.
    """
    if not predictions or not ground_truths:
        return 0.0

    ranked = sorted(predictions, key=lambda p: p['confidence'], reverse=True)

    # Index ground truths by image so matching stays within one image.
    gt_by_image = defaultdict(list)
    for gt_idx, gt in enumerate(ground_truths):
        gt_by_image[gt.get('image_id')].append(gt_idx)

    tp = np.zeros(len(ranked))
    matched_gt = set()
    for i, pred in enumerate(ranked):
        best_iou = 0.0
        best_gt_idx = -1
        for j in gt_by_image.get(pred.get('image_id'), ()):
            if j in matched_gt:
                continue
            iou = calculate_iou(pred['bbox'], ground_truths[j]['bbox'])
            if iou > best_iou:
                best_iou = iou
                best_gt_idx = j

        if best_gt_idx != -1 and best_iou >= iou_threshold:
            tp[i] = 1
            matched_gt.add(best_gt_idx)

    tp_cumsum = np.cumsum(tp)
    recalls = tp_cumsum / len(ground_truths)
    # After k predictions exactly k of them are TP or FP, so the
    # denominator is simply 1..n and can never be zero.
    precisions = tp_cumsum / np.arange(1, len(ranked) + 1)

    # Build the recall levels as i/10 rather than np.arange(0, 1.1, 0.1):
    # the latter yields 0.30000000000000004, which a recall of exactly
    # 3/10 fails to reach, silently dropping one interpolation point.
    total = 0.0
    for t in np.arange(11) / 10:
        reached = recalls >= t
        if reached.any():
            total += precisions[reached].max()
    return float(total / 11)


def load_xml_annotations(folder_path):
    """Load every ``*.xml`` annotation directly inside ``folder_path``.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        Dict mapping each file's base name (without extension) to the
        annotation returned by ``parse_xml_annotation``. Files that fail
        to parse are left out.
    """
    annotations = {}
    # Escape the folder so characters such as '[' in the path are not
    # interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(folder_path), "*.xml")
    for xml_file in sorted(glob.glob(pattern)):
        annotation = parse_xml_annotation(xml_file)
        if annotation is not None:
            base_name = os.path.splitext(os.path.basename(xml_file))[0]
            annotations[base_name] = annotation
    return annotations


def _objects_of_class(annotation, class_name, image_id):
    """Return copies of the annotation's objects of one class, tagged with image_id."""
    return [
        {**obj, 'image_id': image_id}
        for obj in annotation['objects']
        if obj['name'] == class_name
    ]


def calculate_map_for_folders(pred_folder, gt_folder, target_classes,
                              iou_thresholds=(0.5, 0.75, 0.9)):
    """Compute per-class AP and mAP between two folders of XML annotations.

    Files are paired by base name. The ground-truth folder defines the
    evaluation set: a ground-truth file without a prediction file counts
    as "no detections" (its objects are missed), and prediction files
    without a ground-truth file are ignored. A warning is printed in both
    cases.

    Args:
        pred_folder: Folder containing prediction XML files.
        gt_folder: Folder containing ground-truth XML files.
        target_classes: Class names to evaluate. Every listed class enters
            the mean, including classes with no ground truth (AP 0.0).
        iou_thresholds: IoU thresholds to evaluate.

    Returns:
        Dict with, for each threshold ``t``, ``'mAP<100t>'`` (float) and
        ``'class_aps_<100t>'`` (dict of class name -> AP).
    """
    print(f"加载预测结果文件夹: {pred_folder}")
    pred_annotations = load_xml_annotations(pred_folder)
    print(f"加载真值标注文件夹: {gt_folder}")
    gt_annotations = load_xml_annotations(gt_folder)
    print(f"找到预测文件: {len(pred_annotations)} 个")
    print(f"找到真值文件: {len(gt_annotations)} 个")

    extra_preds = set(pred_annotations) - set(gt_annotations)
    if extra_preds:
        print(f"警告: {len(extra_preds)} 个预测文件没有对应的真值文件, 已忽略")
    missing_preds = set(gt_annotations) - set(pred_annotations)
    if missing_preds:
        print(f"警告: {len(missing_preds)} 个真值文件没有对应的预测文件, 按无检测结果计算")

    # Collecting objects does not depend on the IoU threshold, so do it
    # once per class. Sorted keys keep the order (and therefore the
    # tie-breaking between equal confidences) deterministic.
    file_keys = sorted(gt_annotations)
    per_class = {}
    for class_name in target_classes:
        all_predictions = []
        all_ground_truths = []
        for file_key in file_keys:
            all_ground_truths.extend(
                _objects_of_class(gt_annotations[file_key], class_name,
                                  file_key)
            )
            if file_key in pred_annotations:
                all_predictions.extend(
                    _objects_of_class(pred_annotations[file_key],
                                      class_name, file_key)
                )
        per_class[class_name] = (all_predictions, all_ground_truths)

    results = {}
    for iou_thresh in iou_thresholds:
        print(f"\n计算 mAP@{iou_thresh}...")
        class_aps = {}
        for class_name, (all_predictions, all_ground_truths) in per_class.items():
            ap = calculate_ap_at_iou(all_predictions, all_ground_truths,
                                     iou_thresh)
            class_aps[class_name] = ap
            print(f"  {class_name}: AP = {ap:.4f} (预测:{len(all_predictions)}, 真值:{len(all_ground_truths)})")

        mean_ap = float(np.mean(list(class_aps.values()))) if class_aps else 0.0
        # round() rather than int(): 0.57 * 100 is 56.99999999999999.
        key = int(round(iou_thresh * 100))
        results[f'mAP{key}'] = mean_ap
        results[f'class_aps_{key}'] = class_aps
        print(f"  mAP@{iou_thresh}: {mean_ap:.4f}")

    return results


def main():
    """Run the evaluation with the configured folders and save a JSON report.

    Prints mAP@0.5 / 0.75 / 0.9, the weighted score
    ``0.6*mAP50 + 0.2*mAP75 + 0.2*mAP90`` and the per-threshold sums of
    class APs, then writes everything to ``xml_map_results.json`` in the
    current working directory.
    """
    # Configuration: edit these for your own data set.
    folder_a = r"E:\评分\我的答案"  # prediction XML folder
    folder_b = r"E:\评分\标准答案"  # ground-truth XML folder
    target_classes = [
        "person", "car", "bicycle", "motorcycle", "bus", "truck",
        "traffic_light", "stop_sign", "dog", "cat",
        "021_tdhj_xxshywyh_sh/yw_yw",
    ]

    print("=== XML格式mAP计算工具 ===")
    print(f"预测文件夹: {folder_a}")
    print(f"真值文件夹: {folder_b}")
    print(f"目标类别: {target_classes}")

    # isdir (not exists): a regular file at that path is not usable either.
    if not os.path.isdir(folder_a):
        print(f"错误: 预测文件夹 '{folder_a}' 不存在!")
        return
    if not os.path.isdir(folder_b):
        print(f"错误: 真值文件夹 '{folder_b}' 不存在!")
        return

    results = calculate_map_for_folders(folder_a, folder_b, target_classes)

    map50 = results.get('mAP50', 0.0)
    map75 = results.get('mAP75', 0.0)
    map90 = results.get('mAP90', 0.0)
    comprehensive_score = 0.6 * map50 + 0.2 * map75 + 0.2 * map90

    # Sum (rather than mean) of the per-class APs at each threshold.
    total_score_50 = sum(results.get('class_aps_50', {}).values())
    total_score_75 = sum(results.get('class_aps_75', {}).values())
    total_score_90 = sum(results.get('class_aps_90', {}).values())
    weighted_total_score = (
        0.6 * total_score_50 + 0.2 * total_score_75 + 0.2 * total_score_90
    )

    print("\n" + "=" * 60)
    print("最终结果汇总")
    print("=" * 60)
    print(f"mAP@0.5:   {map50:.4f}")
    print(f"mAP@0.75:  {map75:.4f}")
    print(f"mAP@0.9:   {map90:.4f}")
    print("-" * 40)
    print(f"综合得分:  {comprehensive_score:.4f}")
    print("权重配置: 0.6*mAP50 + 0.2*mAP75 + 0.2*mAP90")
    print("-" * 40)
    print("所有分类总得分:")
    print(f"  总AP@0.5:  {total_score_50:.4f}")
    print(f"  总AP@0.75: {total_score_75:.4f}")
    print(f"  总AP@0.9:  {total_score_90:.4f}")
    print(f"  加权总得分: {weighted_total_score:.4f}")
    print("=" * 60)

    output_results = {
        'mAP50': map50,
        'mAP75': map75,
        'mAP90': map90,
        'comprehensive_score': comprehensive_score,
        'total_scores': {
            'total_AP50': total_score_50,
            'total_AP75': total_score_75,
            'total_AP90': total_score_90,
            'weighted_total_score': weighted_total_score,
        },
        'target_classes': target_classes,
        'weights': {'mAP50': 0.6, 'mAP75': 0.2, 'mAP90': 0.2},
        'class_details': {
            'mAP50_classes': results.get('class_aps_50', {}),
            'mAP75_classes': results.get('class_aps_75', {}),
            'mAP90_classes': results.get('class_aps_90', {}),
        },
    }

    with open('xml_map_results.json', 'w', encoding='utf-8') as f:
        json.dump(output_results, f, indent=2, ensure_ascii=False)
    print("\n详细结果已保存到: xml_map_results.json")


if __name__ == "__main__":
    main()
http://www.dtcms.com/a/298853.html

相关文章:

  • KNN 算法进阶:从基础到优化的深度解析
  • GaussDB 数据库架构师修炼(九) 逻辑备份实操
  • 动态规划Day1学习心得
  • JavaWeb项目(纯Servlet+JSP+前端三大件)入门(从0开始)
  • JavaSE-图书信息管理系统
  • jwt 在net9.0中做身份认证
  • 2507C++,窗口勾挂事件
  • IPv6,你开始使用了吗?
  • MATLAB 设置默认启动路径为上次关闭路径的方法
  • Linux C : 指针
  • ZYNQ芯片,SPI驱动开发自学全解析个人笔记【FPGA】【赛灵思】
  • 您的需求已被采纳
  • 【51单片机简易红绿灯计数延时】2022-9-23
  • AIStarter平台亮点解析:从ComfyUI项目上架到一键运行的完整指南
  • I/O多路复用机制中触发机制详细解析
  • 数字化转型-AI落地金字塔法则
  • 【补题】Codeforces Round 735 (Div. 2) B. Cobb
  • 卡尔曼滤波器噪声方差设置对性能影响的仿真研究
  • 【LeetCode刷题指南】--有效的括号
  • K-近邻算法中的近似误差和估计误差
  • Win11批量部署神器winget
  • SQL基础⑯ | MySQL8新特性篇
  • JDK8保姆级安装教程
  • 新房装修是中央空调还是壁挂空调好?
  • 安卓上的迷之K_1171477665
  • 扒网站工具 HTTrack Website Copier
  • Netty中future和promise用法和区别
  • HashMap的线程安全性 vs ConcurrentHashMap
  • 从零开发Java坦克大战:架构设计与难点突破 (下)
  • Java测试题(上)