当前位置: 首页 > news >正文

【即插即用涨点模块】DSConv动态蛇形卷积:自适应聚焦细长弯曲的局部结构特征,助力分割高效提点【附源码+注释】

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于深度学习的行人跌倒检测系统】
9.【基于深度学习的PCB板缺陷检测系统】10.【基于深度学习的生活垃圾分类目标检测系统】
11.【基于深度学习的安全帽目标检测系统】12.【基于深度学习的120种犬类检测与识别系统】
13.【基于深度学习的路面坑洞检测系统】14.【基于深度学习的火焰烟雾检测系统】
15.【基于深度学习的钢材表面缺陷检测系统】16.【基于深度学习的舰船目标分类检测系统】
17.【基于深度学习的西红柿成熟度检测系统】18.【基于深度学习的血细胞检测与计数系统】
19.【基于深度学习的吸烟/抽烟行为检测系统】20.【基于深度学习的水稻害虫检测与识别系统】
21.【基于深度学习的高精度车辆行人检测与计数系统】22.【基于深度学习的路面标志线检测与识别系统】
23.【基于深度学习的智能小麦害虫检测识别系统】24.【基于深度学习的智能玉米害虫检测识别系统】
25.【基于深度学习的200种鸟类智能检测与识别系统】26.【基于深度学习的45种交通标志智能检测与识别系统】
27.【基于深度学习的人脸面部表情识别系统】28.【基于深度学习的苹果叶片病害智能诊断系统】
29.【基于深度学习的智能肺炎诊断系统】30.【基于深度学习的葡萄簇目标检测系统】
31.【基于深度学习的100种中草药智能识别系统】32.【基于深度学习的102种花卉智能识别系统】
33.【基于深度学习的100种蝴蝶智能识别系统】34.【基于深度学习的水稻叶片病害智能诊断系统】
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统】36.【基于深度学习的智能草莓病害检测与分割系统】
37.【基于深度学习的复杂场景下船舶目标检测系统】38.【基于深度学习的农作物幼苗与杂草检测系统】
39.【基于深度学习的智能道路裂缝检测与分析系统】40.【基于深度学习的葡萄病害智能诊断与防治系统】
41.【基于深度学习的遥感地理空间物体检测系统】42.【基于深度学习的无人机视角地面物体检测系统】
43.【基于深度学习的木薯病害智能诊断与防治系统】44.【基于深度学习的野外火焰烟雾检测系统】
45.【基于深度学习的脑肿瘤智能检测系统】46.【基于深度学习的玉米叶片病害智能诊断与防治系统】
47.【基于深度学习的橙子病害智能诊断与防治系统】48.【基于深度学习的车辆检测追踪与流量计数系统】
49.【基于深度学习的行人检测追踪与双向流量计数系统】50.【基于深度学习的反光衣检测与预警系统】
51.【基于深度学习的危险区域人员闯入检测与报警系统】52.【基于深度学习的高密度人脸智能检测与统计系统】
53.【基于深度学习的CT扫描图像肾结石智能检测系统】54.【基于深度学习的水果智能检测系统】
55.【基于深度学习的水果质量好坏智能检测系统】56.【基于深度学习的蔬菜目标检测与识别系统】
57.【基于深度学习的非机动车驾驶员头盔检测系统】58.【太基于深度学习的阳能电池板检测与分析系统】
59.【基于深度学习的工业螺栓螺母检测】60.【基于深度学习的金属焊缝缺陷检测系统】
61.【基于深度学习的链条缺陷检测与识别系统】62.【基于深度学习的交通信号灯检测识别】
63.【基于深度学习的草莓成熟度检测与识别系统】64.【基于深度学习的水下海生物检测识别系统】
65.【基于深度学习的道路交通事故检测识别系统】66.【基于深度学习的安检X光危险品检测与识别系统】
67.【基于深度学习的农作物类别检测与识别系统】68.【基于深度学习的危险驾驶行为检测识别系统】
69.【基于深度学习的维修工具检测识别系统】70.【基于深度学习的维修工具检测识别系统】
71.【基于深度学习的建筑墙面损伤检测系统】72.【基于深度学习的煤矿传送带异物检测系统】
73.【基于深度学习的老鼠智能检测系统】74.【基于深度学习的水面垃圾智能检测识别系统】
75.【基于深度学习的遥感视角船只智能检测系统】76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统】
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】
79.【基于深度学习的果园苹果检测与计数系统】80.【基于深度学习的半导体芯片缺陷检测系统】
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统】82.【基于深度学习的运动鞋品牌检测与识别系统】
83.【基于深度学习的苹果叶片病害检测识别系统】84.【基于深度学习的医学X光骨折检测与语音提示系统】
85.【基于深度学习的遥感视角农田检测与分割系统】86.【基于深度学习的运动品牌LOGO检测与识别系统】
87.【基于深度学习的电瓶车进电梯检测与语音提示系统】88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统】
89.【基于深度学习的医学CT图像肺结节智能检测与语音提示系统】90.【基于深度学习的舌苔舌象检测识别与诊断系统】
91.【基于深度学习的蛀牙智能检测与语音提示系统】92.【基于深度学习的皮肤癌智能检测与语音提示系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

  • 论文信息
  • 摘要
  • 方法
    • 1. 动态蛇形卷积(DSConv)
      • 2. 多视角特征融合
      • 3. 拓扑连续性约束损失
    • 创新点
  • DSConv的作用
  • 总结
  • DSConv源码+注释

论文信息

在这里插入图片描述

论文地址:https://arxiv.org/abs/2307.08388
源码地址:https://github.com/YaoleiQi/DSCNet

摘要

在这里插入图片描述

本文提出了一种针对管状结构(如血管、道路)分割的新框架DSCNet,通过融合拓扑几何约束知识,在特征提取、特征融合和损失约束三阶段增强模型感知能力。主要贡献包括:

  1. 动态蛇形卷积(DSConv)​​:自适应聚焦细长弯曲的局部结构特征(对应挑战1:脆弱局部结构)
  2. 多视角特征融合策略​:通过多形态卷积核模板保留不同全局形态的关键信息(对应挑战2:复杂全局形态)
  3. 拓扑连续性约束损失(TCLoss)​​:基于持续同调理论约束分割结果的拓扑连续性
    实验表明,DSCNet在2D/3D数据集上均优于现有方法。
    在这里插入图片描述

方法

1. 动态蛇形卷积(DSConv)

  • 核心思想​:将标准卷积核改造为蛇形路径,通过迭代式偏移累积适应管状结构形态。相比可变形卷积,DSConv通过轴向约束避免感知区域漂移。
  • 实现细节​:
    • X/Y轴方向分别进行坐标线性化(图展示9×9核的变形过程)
    • 使用双线性插值处理分数坐标
      在这里插入图片描述

2. 多视角特征融合

  • 策略流程​:
    1. 生成多组DSConv模板提取不同视角特征
    2. 引入随机丢弃策略减少冗余噪声
    3. 保留最优模板组合用于测试阶段
      在这里插入图片描述

3. 拓扑连续性约束损失

在这里插入图片描述

  • 持续同调应用​:
    • 构建持久图(PD)记录拓扑特征(0维/1维同调)的生存周期
    • 采用Hausdorff距离度量预测与真值的拓扑差异
  • 联合损失​: L T C = L C E + ∑ d H ∗ \mathcal{L}_{TC} = \mathcal{L}_{CE} + \sum d_H^* LTC=LCE+dH

创新点

创新模块解决痛点技术亮点
DSConv细小局部结构易丢失蛇形路径约束+迭代偏移
多视角融合复杂形态导致过拟合随机丢弃的模板组合优化
TCLoss分割断裂问题持续同调+异常拓扑惩罚

DSConv的作用

在这里插入图片描述

  1. 形态适应性​:

    • 相比可变形卷积(易偏离目标),DSConv严格沿管状结构延伸
    • 通过 Σ Δ \Sigma \Delta ΣΔ约束实现线性感知
  2. 特征增强​:

    • 热力图显示对管状区域的高响应
    • 在DRIVE数据集上Dice提升1.23%
  3. 跨维度扩展​:

    • 支持3D数据(冠状动脉CTA)的分割,HD降低19.3%

总结

  1. 贡献总结​:

    • 提出首个融合几何形态与拓扑约束的管状结构分割框架
    • DSConv和TCLoss可迁移至其他网络(表1中UNet+TCLoss提升 β 0 \beta_0 β0误差)
  2. 实验验证​:

    • 2D数据集:DRIVE道路分割OF指标提升6%(表2)
    • 3D数据集:冠状动脉RCA分支HD降至5.787(表2)
  3. 未来方向​:

    • 探索其他形态目标的适应性(如神经元分支)
    • 结合更多拓扑不变量(如Betti数高阶约束)

DSConv源码+注释

# -*- coding: utf-8 -*-
import os
import torch
import numpy as np
from torch import nn
import warnings
# 忽略警告信息以清洁输出
warnings.filterwarnings("ignore")"""
This code is mainly the deformation process of our DSConv
"""class DSConv(nn.Module):def __init__(self, in_ch, out_ch, kernel_size, extend_scope, morph, if_offset, device):"""初始化动态蛇形卷积层(Dynamic Snake Convolution)。:param in_ch: 输入通道数。:param out_ch: 输出通道数。:param kernel_size: 卷积核大小。:param extend_scope: 动态调整的最大位移范围。:param morph: 核变形方向,'0' 为沿x轴,'1' 为沿y轴。:param if_offset: 是否启用动态形变。:param device: 运算设备('cuda' 或 'cpu')。"""super(DSConv, self).__init__()# 可学习偏移量以实现动态形变self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)self.bn = nn.BatchNorm2d(2 * kernel_size)self.kernel_size = kernel_size# 针对x轴和y轴形变的卷积操作self.dsc_conv_x = nn.Conv2d(in_ch, out_ch, kernel_size=(kernel_size, 1), stride=(kernel_size, 1), padding=0)self.dsc_conv_y = nn.Conv2d(in_ch, out_ch, kernel_size=(1, kernel_size), stride=(1, kernel_size), padding=0)# 分组归一化和ReLU激活函数self.gn = nn.GroupNorm(out_ch // 4, out_ch)self.relu = nn.ReLU(inplace=True)# 配置设置self.extend_scope = extend_scopeself.morph = morphself.if_offset = if_offsetself.device = devicedef forward(self, f):"""DSConv层的前向传播。:param f: 输入特征图。:return: 变换后的特征图。"""# 计算偏移量并应用批归一化offset = self.offset_conv(f)offset = self.bn(offset)offset = torch.tanh(offset)  # 将偏移量规范化到[-1, 1]范围内# 根据输入形状和设置准备可变形卷积dsc = DSC(f.shape, self.kernel_size, self.extend_scope, self.morph, self.device)deformed_feature = dsc.deform_conv(f, offset, self.if_offset)# 根据形变方向应用卷积if self.morph == 0:x = self.dsc_conv_x(deformed_feature)else:x = self.dsc_conv_y(deformed_feature)# 归一化和激活输出x = self.gn(x)x = self.relu(x)return x# 辅助类处理形变过程
class DSC(object):def __init__(self, input_shape, kernel_size, extend_scope, morph, device):"""初始化用于处理形变的DSC对象。:param input_shape: 输入张量的形状。:param kernel_size: 卷积核大小。:param extend_scope: 形变范围。:param morph: 形变方向。:param device: 运算设备。"""self.num_points = kernel_sizeself.width = input_shape[2]self.height = input_shape[3]self.morph = morphself.device = deviceself.extend_scope = extend_scopeself.num_batch = input_shape[0]self.num_channels = input_shape[1]def _coordinate_map_3D(self, offset, if_offset):"""根据给定的偏移量生成3D坐标映射。:param offset: 卷积层学习到的偏移量。:param if_offset: 是否应用形变的标志。:return: 用于形变的坐标映射。"""# 分解偏移量为x轴和y轴分量y_offset, x_offset = torch.split(offset, self.num_points, dim=1)# 生成y和x的基础网格坐标y_center = torch.arange(0, self.width).repeat(self.height).view(self.height, self.width).permute(1, 0). \repeat(self.num_points, 1, 1).float().unsqueeze(0)x_center = torch.arange(0, self.height).repeat(self.width).view(self.width, self.height).permute(0, 1). \repeat(self.num_points, 1, 1).float().unsqueeze(0)if self.morph == 0:# 沿x轴形变的情况# 初始化网格的y坐标,所有y坐标都为0y = torch.linspace(0, 0, 1)# 初始化网格的x坐标,均匀分布在-kernel_size/2到kernel_size/2之间x = torch.linspace(-int(self.num_points // 2), int(self.num_points // 2), int(self.num_points))# 生成网格坐标y, x = torch.meshgrid(y, x)y_spread = y.reshape(-1, 1)x_spread = x.reshape(-1, 1)# 重复网格坐标以覆盖整个特征图的宽度和高度y_grid = y_spread.repeat([1, self.width * self.height])y_grid = y_grid.reshape([self.num_points, self.width, self.height])y_grid = y_grid.unsqueeze(0)  # 增加批次维度x_grid = x_spread.repeat([1, self.width * self.height])x_grid = x_grid.reshape([self.num_points, self.width, self.height])x_grid = x_grid.unsqueeze(0)  # 增加批次维度# 将基准中心坐标与网格坐标相加以生成新的形变坐标y_new = y_center + y_gridx_new = x_center + x_grid# 复制到每个批次y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(self.device)x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(self.device)# 根据偏移量调整y坐标y_offset_new = y_offset.detach().clone()if if_offset:y_offset = y_offset.permute(1, 0, 2, 3)y_offset_new = y_offset_new.permute(1, 0, 2, 3)center = int(self.num_points // 2)# 中心位置保持不变,其他位置根据偏移量动态调整y_offset_new[center] = 0for index in range(1, center):y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(self.device)y_new = y_new.add(y_offset_new.mul(self.extend_scope))# 重塑形变后的坐标以适应下一步操作y_new = y_new.reshape([self.num_batch, self.num_points, 1, self.width, self.height])y_new = y_new.permute(0, 3, 1, 4, 2)y_new = y_new.reshape([self.num_batch, self.num_points * self.width, 1 * self.height])x_new = x_new.reshape([self.num_batch, self.num_points, 1, self.width, self.height])x_new = x_new.permute(0, 3, 1, 4, 2)x_new = x_new.reshape([self.num_batch, self.num_points * self.width, 1 * self.height])return y_new, x_newelse:# 沿y轴形变的情况# 初始化网格的y坐标,均匀分布在-kernel_size/2到kernel_size/2之间y = torch.linspace(-int(self.num_points // 2), int(self.num_points // 2), int(self.num_points))# 初始化网格的x坐标,所有x坐标都为0x = torch.linspace(0, 0, 1)# 生成网格坐标y, x = torch.meshgrid(y, x)y_spread = y.reshape(-1, 1)x_spread = x.reshape(-1, 1)# 重复网格坐标以覆盖整个特征图的宽度和高度y_grid = y_spread.repeat([1, self.width * self.height])y_grid = y_grid.reshape([self.num_points, self.width, self.height])y_grid = y_grid.unsqueeze(0)  # 增加批次维度x_grid = x_spread.repeat([1, self.width * self.height])x_grid = x_grid.reshape([self.num_points, self.width, self.height])x_grid = x_grid.unsqueeze(0)  # 增加批次维度# 将基准中心坐标与网格坐标相加以生成新的形变坐标y_new = y_center + y_gridx_new = x_center + x_grid# 复制到每个批次y_new = y_new.repeat(self.num_batch, 1, 1, 1)x_new = x_new.repeat(self.num_batch, 1, 1, 1)# 将坐标转移到设备上(如GPU)y_new = y_new.to(self.device)x_new = x_new.to(self.device)# 处理x轴偏移量x_offset_new = x_offset.detach().clone()if if_offset:# 调整偏移数据以应用于所有批次和位置x_offset = x_offset.permute(1, 0, 2, 3)x_offset_new = x_offset_new.permute(1, 0, 2, 3)center = int(self.num_points // 2)# 中心位置保持不变,其他位置根据偏移量动态调整x_offset_new[center] = 0for index in range(1, center):x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(self.device)x_new = x_new.add(x_offset_new.mul(self.extend_scope))# 重塑形变后的坐标以适应下一步操作y_new = y_new.reshape([self.num_batch, 1, self.num_points, self.width, self.height])y_new = y_new.permute(0, 3, 1, 4, 2)y_new = y_new.reshape([self.num_batch, 1 * self.width, self.num_points * self.height])x_new = x_new.reshape([self.num_batch, 1, self.num_points, self.width, self.height])x_new = x_new.permute(0, 3, 1, 4, 2)x_new = x_new.reshape([self.num_batch, 1 * self.width, self.num_points * self.height])return y_new, x_new"""输入:输入特征图 [N,C,D,W,H];坐标映射 [N,K*D,K*W,K*H] 输出:[N,1,K*D,K*W,K*H]  形变后的特征图"""def _bilinear_interpolate_3D(self, input_feature, y, x):# 将坐标向量平铺并转换为浮点数y = y.reshape([-1]).float()x = x.reshape([-1]).float()# 定义网格边界zero = torch.zeros([]).int()max_y = self.width - 1max_x = self.height - 1# 计算网格的四个角的坐标y0 = torch.floor(y).int()y1 = y0 + 1x0 = torch.floor(x).int()x1 = x0 + 1# 限制坐标不超过特征图的边界y0 = torch.clamp(y0, zero, max_y)y1 = torch.clamp(y1, zero, max_y)x0 = torch.clamp(x0, zero, max_x)x1 = torch.clamp(x1, zero, max_x)# 展平输入特征图以便进行索引input_feature_flat = input_feature.flatten()input_feature_flat = input_feature_flat.reshape(self.num_batch, self.num_channels, self.width, self.height)input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)dimension = self.height * self.width# 计算每个批次的基准索引base = torch.arange(self.num_batch) * dimensionbase = base.reshape([-1, 1]).float()repeat = torch.ones([self.num_points * self.width * self.height]).unsqueeze(0)repeat = repeat.float()# 将基准索引复制以匹配特征点数量base = torch.matmul(base, repeat)base = base.reshape([-1])base = base.to(self.device)# 计算周围4个点的索引base_y0 = base + y0 * self.heightbase_y1 = base + y1 * self.heightindex_a0 = base_y0 - base + x0index_c0 = base_y0 - base + x1index_a1 = base_y1 - base + x0index_c1 = base_y1 - base + x1# 获取这四个点的值value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(self.device)value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(self.device)value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(self.device)value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(self.device)# 计算插值权重x0_float = x0.float()x1_float = x1.float()y0_float = y0.float()y1_float = y1.float()vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(self.device)vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(self.device)vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(self.device)vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(self.device)# 根据权重和四个角的值计算插值结果outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 + value_c1 * vol_c1)# 重塑输出以匹配输入的形状if self.morph == 0:outputs = outputs.reshape([self.num_batch, self.num_points * self.width, 1 * self.height, self.num_channels])outputs = outputs.permute(0, 3, 1, 2)else:outputs = outputs.reshape([self.num_batch, 1 * self.width, self.num_points * self.height, self.num_channels])outputs = outputs.permute(0, 3, 1, 2)return outputsdef deform_conv(self, input, offset, if_offset):# 获取形变的坐标映射y, x = self._coordinate_map_3D(offset, if_offset)# 应用双线性插值deformed_feature = self._bilinear_interpolate_3D(input, y, x)return deformed_featureif __name__ == '__main__':os.environ["CUDA_VISIBLE_DEVICES"] = '0'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")A = np.random.rand(4, 5, 6, 7)# A = np.ones(shape=(3, 2, 2, 3), dtype=np.float32)# print(A)A = A.astype(dtype=np.float32)A = torch.from_numpy(A)# print(A.shape)conv0 = DSConv(in_ch=5,out_ch=10,kernel_size=15,extend_scope=1,morph=0,if_offset=True,device=device)if torch.cuda.is_available():A = A.to(device)conv0 = conv0.to(device)out = conv0(A)print(out.shape)

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

相关文章:

  • 从简历筛选到面试管理:开发一站式智能招聘系统源码详解
  • JavaScript 性能优化全攻略:从基础到实战
  • 瑞芯微RK3288解决方案:高性能、高扩展性的嵌入式系统设计理念与应用分析
  • C++ 深入解析 数据结构中的 AVL树的插入 涉及的旋转规则
  • 小米 MiMo 开源:7B 参数凭什么 “叫板” AI行业巨头?
  • 【今日三题】ISBN号码(模拟) / kotori和迷宫(BFS最短路) / 矩阵最长递增路径(dfs)
  • 红黑树的应用场景 —— 进程调度 CFS 与内存管理
  • 视频编解码学习7之视频编码简介
  • 6. 进程控制
  • 初学者的AI智能体课程:构建AI智能体的十堂课
  • 在k8s中,如何实现服务的访问,k8s的ip是变化的,怎么保证能访问到我的服务
  • Perspective,数据可视化的超级引擎!
  • K8S常见问题汇总
  • PDF生成模块开发经验分享
  • [5-2] 对射式红外传感器计次旋转编码器计次 江协科技学习笔记(38个知识点)
  • XL32F001国产低成本单片机,24MHz主频,24KB Flash,3KB SRAM
  • 【探寻C++之旅】第十三章:红黑树
  • Python Cookbook-7.8 使用 Berkeley DB 数据库
  • TensorFlow 2.x入门实战:从零基础到图像分类项目
  • 物流无人机自动化装卸技术解析!
  • 上海一中院一审公开开庭审理被告人胡欣受贿案
  • “子宫内膜异位症”相关论文男性患者样本超六成?福建省人民医院发布情况说明
  • 习近平在俄罗斯媒体发表署名文章
  • 李云泽:将尽快推出支持小微企业民营企业融资一揽子政策
  • 潘功胜:央行将设立5000亿元服务消费与养老再贷款
  • 五月A股怎么买?券商金股电子权重第一,格力电器最热