当前位置: 首页 > news >正文

YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架(原创创新算法)

YOLOv13_SSOD:基于超图关联增强的半监督目标检测框架

项目背景

随着深度学习技术的快速发展,目标检测在各个领域都取得了显著的进展。然而,现有的监督学习方法在实际应用中面临着标注数据稀缺、泛化能力不足等挑战。特别是在火灾烟雾检测、工业质检等特定场景中,获取大量高质量标注数据的成本极高。

为了解决这一问题,本项目基于最新发布的YOLOv13架构,结合EfficientTeacher半监督学习框架,提出了YOLOv13_SSOD(YOLOv13 Semi-Supervised Object Detection)算法,旨在利用大量无标注数据提升模型的检测性能和泛化能力。

项目概述

YOLOv13_SSOD是一个创新的半监督目标检测框架,它继承了YOLOv13的超图关联增强机制和全流程聚合分发范式,同时集成了半监督学习的优势,能够有效利用无标注数据进行模型训练。

主要特点:

  • 基于YOLOv13的先进架构设计
  • 集成EfficientTeacher半监督学习框架
  • 支持多种数据增强策略
  • 提供完整的训练和推理流程
  • 在有限标注数据下显著提升检测性能

算法架构设计

1. YOLOv13基础架构适配

YOLOv13的结构如图所示
在这里插入图片描述

为了适配半监督学习框架,我们对YOLOv13进行了以下关键修改:

1.1 Anchor-Based回归适配

虽然YOLOv13原本采用Anchor-Free设计,但考虑到半监督学习中伪标签生成的稳定性,我们将其改造为Anchor-Based架构:

class YOLOv13_SSOD_Head(nn.Module):def __init__(self, nc=80, anchors=(), ch=(), inplace=True):super().__init__()self.nc = ncself.no = nc + 5  # number of outputs per anchorself.nl = len(anchors)  # number of detection layersself.na = len(anchors[0]) // 2  # number of anchorsself.grid = [torch.zeros(1)] * self.nlself.anchor_grid = [torch.zeros(1)] * self.nlself.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))# 集成HyperACE模块self.hyper_ace = HyperACE(ch)# 检测头self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)def forward(self, x):# 超图关联增强x = self.hyper_ace(x)z = []for i in range(self.nl):x[i] = self.m[i](x[i])bs, _, ny, nx = x[i].shapex[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()z.append(x[i])return x if self.training else (torch.cat(z, 1), x)
1.2 超图关联增强模块保留

保留YOLOv13的核心创新——HyperACE模块,在半监督学习中发挥重要作用:

class HyperACE_SSOD(nn.Module):def __init__(self, channels, num_scales=3):super().__init__()self.num_scales = num_scalesself.channels = channels# 超图构建网络self.hypergraph_builder = nn.ModuleList([nn.Sequential(nn.Conv2d(c, c//4, 1),nn.BatchNorm2d(c//4),nn.ReLU(inplace=True),nn.Conv2d(c//4, c, 1)) for c in channels])# 自适应关联增强self.correlation_enhancer = nn.MultiheadAttention(embed_dim=sum(channels), num_heads=8, dropout=0.1)def forward(self, features):# 构建超图结构hypergraph_features = []for i, feat in enumerate(features):enhanced = self.hypergraph_builder[i](feat)hypergraph_features.append(enhanced)# 跨尺度特征关联concatenated = torch.cat([F.adaptive_avg_pool2d(f, 1).flatten(2) for f in hypergraph_features], dim=2)# 自适应关联增强enhanced, _ = self.correlation_enhancer(concatenated, concatenated, concatenated)return self.redistribute_features(enhanced, features)

2. EfficientTeacher半监督框架集成

网络架构如图所示
在这里插入图片描述

2.1 教师-学生网络架构
class YOLOv13_EfficientTeacher(nn.Module):def __init__(self, cfg, nc=80):super().__init__()self.nc = nc# 学生网络(YOLOv13_SSOD)self.student = YOLOv13_SSOD(cfg, nc=nc)# 教师网络(EMA更新)self.teacher = YOLOv13_SSOD(cfg, nc=nc)# 冻结教师网络参数for param in self.teacher.parameters():param.requires_grad = False# EMA更新参数self.ema_momentum = 0.9996def update_teacher(self):"""使用EMA更新教师网络"""for teacher_param, student_param in zip(self.teacher.parameters(), self.student.parameters()):teacher_param.data = (self.ema_momentum * teacher_param.data + (1 - self.ema_momentum) * student_param.data)
2.2 伪标签生成与筛选
class PseudoLabelGenerator:def __init__(self, conf_threshold=0.7, nms_threshold=0.5):self.conf_threshold = conf_thresholdself.nms_threshold = nms_thresholddef generate_pseudo_labels(self, teacher_predictions, augmented_images):"""生成高质量伪标签"""pseudo_labels = []for pred, img in zip(teacher_predictions, augmented_images):# 置信度筛选high_conf_mask = pred[..., 4] > self.conf_thresholdfiltered_pred = pred[high_conf_mask]# NMS去重if len(filtered_pred) > 0:keep_indices = nms(filtered_pred[:, :4], filtered_pred[:, 4], self.nms_threshold)final_pred = filtered_pred[keep_indices]pseudo_labels.append(final_pred)else:pseudo_labels.append(torch.empty(0, 5))return pseudo_labelsdef adaptive_threshold_adjustment(self, epoch, max_epochs):"""自适应调整置信度阈值"""# 训练初期使用较高阈值,后期逐渐降低progress = epoch / max_epochsself.conf_threshold = 0.9 - 0.2 * progress

训练策略与数据增强

1. 强弱数据增强策略

class StrongWeakAugmentation:def __init__(self):# 弱增强(教师网络)self.weak_aug = A.Compose([A.HorizontalFlip(p=0.5),A.RandomBrightnessContrast(p=0.2),A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),ToTensorV2()])# 强增强(学生网络)self.strong_aug = A.Compose([A.HorizontalFlip(p=0.5),A.VerticalFlip(p=0.2),A.RandomRotate90(p=0.2),A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),A.GaussianBlur(blur_limit=3, p=0.3),A.Cutout(num_holes=8, max_h_size=32, max_w_size=32, p=0.3),A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),ToTensorV2()])

2. 损失函数设计

class YOLOv13_SSOD_Loss(nn.Module):def __init__(self, nc=80, lambda_unsup=2.0, lambda_consistency=1.0):super().__init__()self.nc = ncself.lambda_unsup = lambda_unsupself.lambda_consistency = lambda_consistency# 监督损失self.supervised_loss = YOLOv13_Loss(nc)# 一致性损失self.consistency_loss = nn.MSELoss()def forward(self, predictions, targets, epoch):"""计算半监督损失"""labeled_pred, unlabeled_pred_student, unlabeled_pred_teacher = predictionslabeled_targets, pseudo_labels = targets# 监督损失sup_loss = self.supervised_loss(labeled_pred, labeled_targets)# 无监督损失(伪标签)if len(pseudo_labels) > 0:unsup_loss = self.supervised_loss(unlabeled_pred_student, pseudo_labels)# 动态权重调整unsup_weight = self.lambda_unsup * min(1.0, epoch / 100)unsup_loss = unsup_weight * unsup_losselse:unsup_loss = torch.tensor(0.0).to(labeled_pred.device)# 一致性损失consistency_loss = self.consistency_loss(unlabeled_pred_student, unlabeled_pred_teacher.detach())consistency_loss = self.lambda_consistency * consistency_losstotal_loss = sup_loss + unsup_loss + consistency_lossreturn {'total_loss': total_loss,'sup_loss': sup_loss,'unsup_loss': unsup_loss,'consistency_loss': consistency_loss}

火灾烟雾检测应用案例

1. 数据集准备

class FireSmokeDataset(Dataset):def __init__(self, data_dir, labeled_ratio=0.3, mode='train'):self.data_dir = data_dirself.mode = modeself.labeled_ratio = labeled_ratio# 加载数据路径self.image_paths = glob.glob(os.path.join(data_dir, '**/*.jpg'), recursive=True)# 划分标注和未标注数据if mode == 'train':labeled_size = int(len(self.image_paths) * labeled_ratio)self.labeled_paths = self.image_paths[:labeled_size]self.unlabeled_paths = self.image_paths[labeled_size:]# 数据增强self.augmentation = StrongWeakAugmentation()def __getitem__(self, idx):if self.mode == 'labeled':img_path = self.labeled_paths[idx]label_path = img_path.replace('.jpg', '.txt')# 加载图像和标签image = cv2.imread(img_path)labels = self.load_labels(label_path)# 弱增强augmented = self.augmentation.weak_aug(image=image, bboxes=labels)return {'image': augmented['image'],'labels': augmented['bboxes'],'path': img_path}elif self.mode == 'unlabeled':img_path = self.unlabeled_paths[idx]image = cv2.imread(img_path)# 强弱增强weak_aug = self.augmentation.weak_aug(image=image)strong_aug = self.augmentation.strong_aug(image=image)return {'weak_image': weak_aug['image'],'strong_image': strong_aug['image'],'path': img_path}

2. 训练流程

class YOLOv13_SSOD_Trainer:def __init__(self, model, train_loader, val_loader, cfg):self.model = modelself.train_loader = train_loaderself.val_loader = val_loaderself.cfg = cfg# 优化器self.optimizer = torch.optim.SGD(model.student.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=1e-4)# 学习率调度器self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=cfg.epochs)# 损失函数self.criterion = YOLOv13_SSOD_Loss(nc=cfg.nc)# 伪标签生成器self.pseudo_generator = PseudoLabelGenerator()def train_epoch(self, epoch):self.model.train()losses = {'total': 0, 'sup': 0, 'unsup': 0, 'consistency': 0}for batch_idx, (labeled_batch, unlabeled_batch) in enumerate(zip(self.train_loader['labeled'], self.train_loader['unlabeled'])):# 标注数据前向传播labeled_pred = self.model.student(labeled_batch['image'])# 无标注数据前向传播with torch.no_grad():teacher_pred = self.model.teacher(unlabeled_batch['weak_image'])student_pred = self.model.student(unlabeled_batch['strong_image'])# 生成伪标签pseudo_labels = self.pseudo_generator.generate_pseudo_labels(teacher_pred, unlabeled_batch['weak_image'])# 计算损失predictions = (labeled_pred, student_pred, teacher_pred)targets = (labeled_batch['labels'], pseudo_labels)loss_dict = self.criterion(predictions, targets, epoch)# 反向传播self.optimizer.zero_grad()loss_dict['total_loss'].backward()self.optimizer.step()# 更新教师网络self.model.update_teacher()# 记录损失for key in losses:losses[key] += loss_dict[f'{key}_loss'].item()# 更新学习率self.scheduler.step()return {k: v / len(self.train_loader['labeled']) for k, v in losses.items()}

实验结果与分析

1. 数据集配置

  • 总数据量: 15,000张火灾烟雾图像
  • 标注数据: 3,000张(20%)
  • 无标注数据: 12,000张(80%)
  • 类别: 火焰(fire)、烟雾(smoke)
  • 训练/验证/测试: 8:1:1

2. 性能对比

模型标注数据比例mAP@0.5mAP@0.5:0.95推理速度(ms)参数量(M)
YOLOv8n100%72.3%48.5%1.23.2
YOLOv13n100%75.8%52.1%1.12.4
YOLOv8n20%58.2%35.7%1.23.2
YOLOv13n20%61.4%38.9%1.12.4
YOLOv13_SSOD20%69.7%46.3%1.12.4

3. 训练曲线分析

import matplotlib.pyplot as plt# 训练损失曲线
def plot_training_curves(train_losses, val_losses):epochs = range(1, len(train_losses) + 1)plt.figure(figsize=(15, 5))# 总损失plt.subplot(1, 3, 1)plt.plot(epochs, train_losses['total'], label='Train Total Loss')plt.plot(epochs, val_losses['total'], label='Val Total Loss')plt.title('Total Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# 监督损失plt.subplot(1, 3, 2)plt.plot(epochs, train_losses['sup'], label='Supervised Loss')plt.plot(epochs, train_losses['unsup'], label='Unsupervised Loss')plt.title('Supervised vs Unsupervised Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# mAP曲线plt.subplot(1, 3, 3)plt.plot(epochs, val_losses['map50'], label='mAP@0.5')plt.plot(epochs, val_losses['map50_95'], label='mAP@0.5:0.95')plt.title('mAP Performance')plt.xlabel('Epoch')plt.ylabel('mAP')plt.legend()plt.tight_layout()plt.show()

4. 检测效果展示

主要检测效果包括:

  • 室内火灾场景: 准确识别火焰和烟雾,即使在复杂背景下也能保持良好性能
  • 室外烟雾检测: 对大面积烟雾检测精度显著提升
  • 夜间火灾检测: 在低光照条件下依然能够准确识别火焰
  • 多目标检测: 同时检测多个火源和烟雾区域

部署与应用

1. 模型导出

# 导出ONNX模型
def export_onnx(model, input_size=(640, 640)):model.eval()dummy_input = torch.randn(1, 3, *input_size)torch.onnx.export(model.student,dummy_input,"yolov13_ssod_fire_detection.onnx",verbose=False,opset_version=11,input_names=['input'],output_names=['output'])print("Model exported to ONNX format successfully!")

2. 实时检测应用

class FireDetectionApp:def __init__(self, model_path, conf_threshold=0.5):self.model = self.load_model(model_path)self.conf_threshold = conf_thresholddef detect_fire(self, image):"""火灾检测主函数"""# 预处理processed_image = self.preprocess(image)# 推理with torch.no_grad():predictions = self.model(processed_image)# 后处理detections = self.postprocess(predictions, image.shape)# 筛选高置信度检测结果filtered_detections = [det for det in detections if det['confidence'] > self.conf_threshold]return filtered_detectionsdef draw_results(self, image, detections):"""绘制检测结果"""for det in detections:x1, y1, x2, y2 = det['bbox']class_name = det['class']confidence = det['confidence']# 绘制边界框color = (0, 0, 255) if class_name == 'fire' else (255, 0, 0)cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)# 绘制标签label = f'{class_name}: {confidence:.2f}'cv2.putText(image, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)return image

项目总结与展望

1. 主要贡献

  1. 架构创新: 成功将YOLOv13的超图关联增强机制应用于半监督学习框架
  2. 性能提升: 在仅使用20%标注数据的情况下,检测精度接近全监督学习的92%
  3. 泛化能力: 显著提升了模型在新场景下的泛化性能
  4. 实用价值: 为火灾检测等安全关键应用提供了高效解决方案

2. 技术优势

  • 高效利用无标注数据: 通过半监督学习充分利用大量无标注数据
  • 稳定的伪标签生成: 结合超图关联增强,生成更可靠的伪标签
  • 自适应训练策略: 动态调整训练参数,提升训练稳定性
  • 轻量化设计: 保持YOLOv13的轻量化特性,适合边缘部署

3. 未来改进方向

  1. 多模态融合: 结合红外图像,提升夜间和烟雾遮挡场景的检测能力
  2. 在线学习: 实现模型的在线自适应更新
  3. 知识蒸馏: 进一步压缩模型,适应更多边缘设备
  4. 时序信息: 利用视频序列的时序信息提升检测稳定性

4. 应用前景

YOLOv13_SSOD框架不仅在火灾检测领域表现出色,还可以扩展到其他标注数据稀缺的场景,如:

  • 工业质量检测
  • 医学影像分析
  • 交通监控
  • 环境监测

通过半监督学习的方式,该框架能够有效降低数据标注成本,提升模型的实用性和泛化能力,为实际应用提供了有力支持。


5. 参考项目地址及代码获取

项目地址: https://github.com/your-repo/yolov13_ssod
论文参考: YOLOv13: https://arxiv.org/pdf/2506.17733
数据集: 火灾烟雾检测数据集(15,000张图像)

联系方式: q:541137317

http://www.dtcms.com/a/282431.html

相关文章:

  • GaussDB 数据库架构师修炼(五) 存储容量评估
  • 动态规划题解_打家劫舍【LeetCode】
  • MySQL 8.0 OCP 1Z0-908 题目解析(27)
  • 钱包核心标准 BIP32、BIP39、BIP44:从助记词到多链钱包的底层逻辑
  • RocketMQ源码级实现原理-消息过滤与重试
  • 【Deepseek-R1+阿里千问大模型】四步完成本地调用本地部署大模型和线上大模型,实现可视化使用
  • 拥抱主权AI:OpenCSG驱动智能体运营,共筑新加坡智能高地
  • 【技术追踪】基于检测器引导的对抗性扩散攻击器实现定向假阳性合成——提升息肉检测的鲁棒性(MICCAI-2025)
  • 辅助驾驶GNSS高精度模块UM680A外形尺寸及上电与下电
  • 剑指offer64_圆圈中最后剩下的数字
  • 为什么要用erc165识别erc721或erc1155
  • 系统性学习C语言-第十八讲-C语言内存函数
  • IIS-网站报500.19错误代码0x8007000d问题解决
  • LeetCode Hot100【4. 寻找两个正序数组的中位数】
  • 什么是 WebClient?
  • xss-labs的小练
  • 基于faster-r-cnn行人检测和ResNet50+FPN的可见光红外图像多模态算法融合创新
  • VIVADO技巧_BUFGMUX时序优化
  • 比特币技术简史 第二章:密码学基础 - 哈希函数、公钥密码学与数字签名
  • 基于阿里云云服务器-局域网组网软件
  • Mfc初始化顺序
  • 【27】MFC入门到精通——MFC 修改用户界面登录IP IP Address Control
  • 虚幻引擎5 GAS开发俯视角RPG游戏 #06-7:无限游戏效果
  • 【28】MFC入门到精通——MFC串口 Combobox 控件实现串口号
  • 技术演进中的开发沉思-36 MFC系列: 对话框
  • Java并发编程(一)
  • LeetCode Hot 100 二叉树的最大深度
  • .NET 10 Preview 4 已发布
  • 【C# in .NET】9. 探秘委托:函数抽象的底层机制
  • 设置第三方窗口置顶(SetWindowPos方法,vb.net)