cocodataset数据集可视化
CocoDetection
是 PyTorch 的 torchvision.datasets
模块中用于加载 COCO 格式目标检测数据集的类。COCO(Common Objects in Context)是一个广泛使用的计算机视觉数据集,包含图像以及对应的目标检测、分割等标注信息。
一、功能
CocoDetection类用于加载COCO格式的目标检测数据集,主要功能包括:
1.加载图像和对应的标注文件(JSON文件)
2.解析标注信息,包括目标类别、边界框(bbox)、分割信息等
3.支持数据增强和转换(通过transform参数)
4.提供迭代接口,方便遍历数据集中的样本。
二、参数
CocoDetection
类的构造函数参数如下:
三、返回值
每个样本返回一个元组(image, target):
image: PIL图像对象或经过transform处理后的张量
target: 一个列表,每个元素是一个字典,包含以下字段
· segmentation: 多边形信息
· area: 目标区域面积
· bbox: 目标边界框,格式为[x_min, y_min, width, height]
· image_id: 对应图像的唯一id
· iscrowd: 是否为拥挤目标(0或1)
· id: 该图像的第几个检测框
四、使用示例及可视化
"""""
@Author : jiguotong
@Contact : 1776220977@qq.com
@site :
-----------------------------------------------
@Time : 2025/3/5
@Description: 对cocodataset风格的图像及标注进行可视化
"""""
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torchvision.datasets as datasets
"""获得3个(0, 1)随机数
@return: [tuple]3个(0, 1)随机数组成的元组
"""
def generate_random_color():
return (random.random(), random.random(), random.random())
""" 将cocodataset解析的数据集可视化并保存图像
@param image: [PIL.Image.Image]需要绘制的底图
@param target: [Dict]CocoDetection解析出的annotation字典
@param output_path: [str]输出目录
@return NULL
"""
def visualize_and_save(image, target, output_path):
# 创建画布
_, ax = plt.subplots(1)
ax.imshow(image)
# 绘制每个目标的边界框和类别名称
for annotation in target:
box = annotation['bbox']
class_id = annotation['category_id']
segmentation = annotation['segmentation']
color = generate_random_color() # 为每个目标生成随机颜色
# 1.绘制多边形
for poly in segmentation:
# 将多边形点转换为 (N, 2) 的数组
poly = np.array(poly).reshape(-1, 2)
# 绘制多边形
polygon = patches.Polygon(poly, edgecolor=color, facecolor=color + (0.7,), linewidth=2)
ax.add_patch(polygon)
# 2.提取边界框坐标 (x_min, y_min, width, height)
x, y, width, height = box
rect = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)
pass # for
# 保存可视化结果
plt.axis('off') # 关闭坐标轴
plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
plt.close()
if __name__=="__main__":
# 数据集路径
root = "./data/Images" # 图像文件夹路径
annFile = "./data/annot/val.json" # 标注文件路径
# 加载 COCO 数据集
coco_dataset = datasets.CocoDetection(root=root, annFile=annFile)
# 创建保存可视化结果的文件夹
output_dir = "./coco_visualization"
os.makedirs(output_dir, exist_ok=True)
# 遍历数据集并保存可视化结果
for idx in range(len(coco_dataset)):
# 获取图像和标注
image, target = coco_dataset[idx]
# 获取图像文件名
image_id = target[0]['image_id']
image_info = coco_dataset.coco.loadImgs(image_id)[0]
image_filename = image_info['file_name']
# 构建输出路径
output_path = os.path.join(output_dir, f"visualized_{image_filename}")
# 可视化并保存
visualize_and_save(image, target, output_path)
print(f"Saved visualization for {image_filename} to {output_path}")
pass