当前位置: 首页 > news >正文

【即插即用涨点模块】EGA边缘引导注意力:有效保留高频边缘信息,提升分割精度,助力高效涨点【附源码+注释】

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于深度学习的行人跌倒检测系统】
9.【基于深度学习的PCB板缺陷检测系统】10.【基于深度学习的生活垃圾分类目标检测系统】
11.【基于深度学习的安全帽目标检测系统】12.【基于深度学习的120种犬类检测与识别系统】
13.【基于深度学习的路面坑洞检测系统】14.【基于深度学习的火焰烟雾检测系统】
15.【基于深度学习的钢材表面缺陷检测系统】16.【基于深度学习的舰船目标分类检测系统】
17.【基于深度学习的西红柿成熟度检测系统】18.【基于深度学习的血细胞检测与计数系统】
19.【基于深度学习的吸烟/抽烟行为检测系统】20.【基于深度学习的水稻害虫检测与识别系统】
21.【基于深度学习的高精度车辆行人检测与计数系统】22.【基于深度学习的路面标志线检测与识别系统】
23.【基于深度学习的智能小麦害虫检测识别系统】24.【基于深度学习的智能玉米害虫检测识别系统】
25.【基于深度学习的200种鸟类智能检测与识别系统】26.【基于深度学习的45种交通标志智能检测与识别系统】
27.【基于深度学习的人脸面部表情识别系统】28.【基于深度学习的苹果叶片病害智能诊断系统】
29.【基于深度学习的智能肺炎诊断系统】30.【基于深度学习的葡萄簇目标检测系统】
31.【基于深度学习的100种中草药智能识别系统】32.【基于深度学习的102种花卉智能识别系统】
33.【基于深度学习的100种蝴蝶智能识别系统】34.【基于深度学习的水稻叶片病害智能诊断系统】
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统】36.【基于深度学习的智能草莓病害检测与分割系统】
37.【基于深度学习的复杂场景下船舶目标检测系统】38.【基于深度学习的农作物幼苗与杂草检测系统】
39.【基于深度学习的智能道路裂缝检测与分析系统】40.【基于深度学习的葡萄病害智能诊断与防治系统】
41.【基于深度学习的遥感地理空间物体检测系统】42.【基于深度学习的无人机视角地面物体检测系统】
43.【基于深度学习的木薯病害智能诊断与防治系统】44.【基于深度学习的野外火焰烟雾检测系统】
45.【基于深度学习的脑肿瘤智能检测系统】46.【基于深度学习的玉米叶片病害智能诊断与防治系统】
47.【基于深度学习的橙子病害智能诊断与防治系统】48.【基于深度学习的车辆检测追踪与流量计数系统】
49.【基于深度学习的行人检测追踪与双向流量计数系统】50.【基于深度学习的反光衣检测与预警系统】
51.【基于深度学习的危险区域人员闯入检测与报警系统】52.【基于深度学习的高密度人脸智能检测与统计系统】
53.【基于深度学习的CT扫描图像肾结石智能检测系统】54.【基于深度学习的水果智能检测系统】
55.【基于深度学习的水果质量好坏智能检测系统】56.【基于深度学习的蔬菜目标检测与识别系统】
57.【基于深度学习的非机动车驾驶员头盔检测系统】58.【太基于深度学习的阳能电池板检测与分析系统】
59.【基于深度学习的工业螺栓螺母检测】60.【基于深度学习的金属焊缝缺陷检测系统】
61.【基于深度学习的链条缺陷检测与识别系统】62.【基于深度学习的交通信号灯检测识别】
63.【基于深度学习的草莓成熟度检测与识别系统】64.【基于深度学习的水下海生物检测识别系统】
65.【基于深度学习的道路交通事故检测识别系统】66.【基于深度学习的安检X光危险品检测与识别系统】
67.【基于深度学习的农作物类别检测与识别系统】68.【基于深度学习的危险驾驶行为检测识别系统】
69.【基于深度学习的维修工具检测识别系统】70.【基于深度学习的维修工具检测识别系统】
71.【基于深度学习的建筑墙面损伤检测系统】72.【基于深度学习的煤矿传送带异物检测系统】
73.【基于深度学习的老鼠智能检测系统】74.【基于深度学习的水面垃圾智能检测识别系统】
75.【基于深度学习的遥感视角船只智能检测系统】76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统】
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】
79.【基于深度学习的果园苹果检测与计数系统】80.【基于深度学习的半导体芯片缺陷检测系统】
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统】82.【基于深度学习的运动鞋品牌检测与识别系统】
83.【基于深度学习的苹果叶片病害检测识别系统】84.【基于深度学习的医学X光骨折检测与语音提示系统】
85.【基于深度学习的遥感视角农田检测与分割系统】86.【基于深度学习的运动品牌LOGO检测与识别系统】
87.【基于深度学习的电瓶车进电梯检测与语音提示系统】88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统】
89.【基于深度学习的医学CT图像肺结节智能检测与语音提示系统】90.【基于深度学习的舌苔舌象检测识别与诊断系统】
91.【基于深度学习的蛀牙智能检测与语音提示系统】92.【基于深度学习的皮肤癌智能检测与语音提示系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

      • 摘要
      • 创新点
      • 方法总结
      • EGA模块的作用
      • 总结
  • EGA源码与注释

在这里插入图片描述

论文地址:https://arxiv.org/abs/2309.03329
代码地址:https://github.com/UARK-AICV/MEGANet

摘要

在这里插入图片描述

本文提出了一种名为**多尺度边缘引导注意力网络(MEGANet)**的新方法,用于结肠镜图像中的息肉分割。息肉分割在结直肠癌的早期诊断中起着至关重要的作用,但由于背景复杂、息肉大小和形状多变以及边界模糊,分割任务面临诸多挑战。MEGANet通过结合经典的边缘检测技术和注意力机制,有效保留了高频信息(如边缘和边界),从而提高了分割精度。该方法在五个基准数据集上进行了广泛的实验,结果表明MEGANet在六种评估指标上均优于现有的最先进方法。

创新点

  1. 边缘引导注意力模块(EGA):MEGANet的核心创新是引入了EGA模块,该模块利用拉普拉斯算子来增强边缘信息,解决了弱边界分割问题。

  2. 多尺度边缘信息保留:EGA模块在多个尺度上操作,从低层到高层特征,确保模型能够关注边缘相关信息,从而在每个解码器层次上提升预测精度。

  3. 无参数方法:使用拉普拉斯算子作为无参数方法,有效提取和保留高频边缘信息,避免了传统CNN方法在边缘提取上的不足。

方法总结

MEGANet是一个端到端的框架,包含三个主要模块:

  1. 编码器:负责从输入图像中捕获和抽象特征。
  2. 解码器:专注于提取显著特征,生成与输入图像分辨率匹配的解码图。
  3. 边缘引导注意力模块(EGA):利用拉普拉斯算子增强边缘信息,确保在解码过程中保留高频细节。

MEGANet通过结合编码器、解码器和EGA模块,能够在多个尺度上保留边缘信息,从而提高了息肉分割的精度。

EGA模块的作用

在这里插入图片描述
EGA模块的主要作用是通过拉普拉斯算子提取和保留高频边缘信息,增强模型对弱边界的检测能力。具体来说,EGA模块在每一层接收三个输入:

  1. 编码器特征:来自编码器的视觉特征。
  2. 高频特征:通过拉普拉斯算子提取的边缘信息。
  3. 解码器预测特征:来自更高层的解码器预测特征。

EGA模块通过结合这些输入,生成一个融合特征,该特征能够突出边缘细节,并通过注意力机制引导模型关注关键区域,从而提升分割精度。此外,EGA模块还通过卷积块注意力模块(CBAM)进一步校准特征,确保模型能够准确捕捉边界和背景区域的相关性。

总结

MEGANet通过引入EGA模块,有效解决了息肉分割中的弱边界问题,显著提高了分割精度。该方法在多个数据集上的实验结果表明其优越性,为结直肠癌的早期诊断提供了有力的技术支持。

EGA源码与注释

# Github地址:https://github.com/UARK-AICV/MEGANet
# 论文:MEGANet: Multi-Scale Edge-Guided Attention Network for Weak Boundary Polyp Segmentation, WACV 2024
# 论文地址:https://arxiv.org/abs/2309.03329

import torch
import torch.nn.functional as F
import torch.nn as nn

# 定义高斯核函数,用于生成高斯模糊滤波器
def gauss_kernel(channels=3, cuda=True):
    # 创建一个5x5的高斯核
    kernel = torch.tensor([[1., 4., 6., 4., 1],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    # 归一化高斯核
    kernel /= 256.
    # 将高斯核扩展到多个通道
    kernel = kernel.repeat(channels, 1, 1, 1)
    if cuda:
        # 如果使用GPU,将高斯核移动到GPU
        kernel = kernel.cuda()
    return kernel

# 定义下采样函数,通过每隔一个像素取值实现
def downsample(x):
    return x[:, :, ::2, ::2]

# 定义卷积高斯模糊函数,使用高斯核对图像进行模糊处理
def conv_gauss(img, kernel):
    # 使用反射填充图像边缘
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    # 应用卷积操作进行高斯模糊
    out = F.conv2d(img, kernel, groups=img.shape[1])
    return out

# 定义上采样函数,通过插入零值实现
def upsample(x, channels):
    # 在每个像素之间插入零值
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
    x_up = cc.permute(0, 1, 3, 2)
    # 对上采样后的图像应用高斯模糊
    return conv_gauss(x_up, 4 * gauss_kernel(channels))

# 定义拉普拉斯金字塔的一个层级,计算图像与高斯模糊后上采样的图像的差异
def make_laplace(img, channels):
    # 对图像进行高斯模糊
    filtered = conv_gauss(img, gauss_kernel(channels))
    # 对模糊后的图像进行下采样
    down = downsample(filtered)
    # 对下采样后的图像进行上采样
    up = upsample(down, channels)
    # 如果上采样后的图像尺寸与原图不同,进行插值调整
    if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]:
        up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3]))
    # 计算原图与上采样后的图像的差异
    diff = img - up
    return diff

# 构建拉普拉斯金字塔,包含多个层级的差异图像和最终的下采样图像
def make_laplace_pyramid(img, level, channels):
    current = img
    pyr = []
    for _ in range(level):
        # 对当前图像计算拉普拉斯层级
        filtered = conv_gauss(current, gauss_kernel(channels))
        down = downsample(filtered)
        up = upsample(down, channels)
        if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
            up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
        diff = current - up
        pyr.append(diff)
        current = down
    # 最后一个层级为最终的下采样图像
    pyr.append(current)
    return pyr

# 定义通道注意力模块,用于计算通道级别的注意力权重
class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # 定义MLP网络,用于计算通道注意力权重
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )

    def forward(self, x):
        # 计算平均池化后的通道注意力权重
        avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        # 计算最大池化后的通道注意力权重
        max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        # 将平均池化和最大池化后的权重相加
        channel_att_sum = avg_out + max_out

        # 将权重通过sigmoid函数归一化,并扩展到与输入相同的尺寸
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        # 将权重应用到输入特征图
        return x * scale

# 定义空间注意力模块,用于计算空间级别的注意力权重
class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        # 定义卷积层,用于计算空间注意力权重
        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        # 计算最大池化和平均池化后的特征图
        x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
        # 将特征图通过卷积层计算空间注意力权重
        x_out = self.spatial(x_compress)
        # 将权重通过sigmoid函数归一化
        scale = torch.sigmoid(x_out)  # broadcasting
        # 将权重应用到输入特征图
        return x * scale

# 定义CBAM模块,结合通道注意力和空间注意力
class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        # 初始化通道注意力模块
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
        # 初始化空间注意力模块
        self.SpatialGate = SpatialGate()

    def forward(self, x):
        # 应用通道注意力
        x_out = self.ChannelGate(x)
        # 应用空间注意力
        x_out = self.SpatialGate(x_out)
        return x_out

# 定义Edge-Guided Attention Module(EGA)模块,用于结合边缘信息和预测结果进行特征融合
class EGA(nn.Module):
    def __init__(self, in_channels):
        super(EGA, self).__init__()

        # 定义特征融合卷积层
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True))

        # 定义注意力机制卷积层
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid())

        # 初始化CBAM模块
        self.cbam = CBAM(in_channels)

    def forward(self, edge_feature, x, pred):
        residual = x
        xsize = x.size()[2:]

        # 将预测结果通过sigmoid函数归一化
        pred = torch.sigmoid(pred)

        # 计算背景注意力权重
        background_att = 1 - pred
        # 应用背景注意力权重到特征图
        background_x = x * background_att

        # 计算边界注意力权重
        edge_pred = make_laplace(pred, 1)
        # 应用边界注意力权重到特征图
        pred_feature = x * edge_pred

        # 计算高频特征
        edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True)
        # 应用高频特征到特征图
        input_feature = x * edge_input

        # 将背景特征、边界特征和高频特征进行拼接
        fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1)
        # 应用特征融合卷积层
        fusion_feature = self.fusion_conv(fusion_feature)

        # 计算注意力权重
        attention_map = self.attention(fusion_feature)
        # 应用注意力权重到融合特征
        fusion_feature = fusion_feature * attention_map

        # 将融合特征与残差相加
        out = fusion_feature + residual
        # 应用CBAM模块
        out = self.cbam(out)
        return out

if __name__ == '__main__':
    # 模拟输入张量
    edge_feature = torch.randn(1, 1, 128, 128).cuda()
    x = torch.randn(1, 64, 128, 128).cuda()
    pred = torch.randn(1, 1, 128, 128).cuda()  # pred 通常是1通道

    # 实例化 EGA 类
    block = EGA(64).cuda()

    # 传递输入张量通过 EGA 实例
    output = block(edge_feature, x, pred)

    # 打印输入和输出的形状
    print(edge_feature.size())
    print(x.size())
    print(pred.size())
    print(output.size())

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

相关文章:

  • 告别硬编码:优雅管理状态常量与响应码
  • Ansible Facts变量
  • 相对论之光速
  • IP地址分配
  • Python 中用T = TypeVar(“T“)这个语法定义一个“类型变量”,属于类型提示系统的一部分
  • Java学习打卡-Day18-ArrayList、Vector、LinkedList
  • Ajax原理笔记
  • JDBC数据库连接池技术详解——从传统连接方式到高效连接管理
  • 零拷贝分析
  • LeetCode热题100JS(49/100)第九天|199|114|105|437|236
  • undo log ,redo log 和binlog的区别?
  • 使用 yum 命令安装 MariaDB 指南
  • 安卓edge://inspect 和 chrome://inspect调试移动设备上的网页
  • 瑞幸需要宇树科技
  • UNION,UNION ALL 的详细用法
  • 【leetcode hot 100 437】路径总和Ⅲ
  • Typora 使用教程(标题,段落,字体,列表,区块,代码,脚注,插入图片,表格,目录)
  • 什么是广播系统语言传输指数 STIPA
  • CCF CSP 第30次(2023.05)(1_仓库规划_C++)
  • 关于运行 npm run serve/dev 运行不起来,node_modules Git忽略不了等(问题)
  • 线下哪些商家支持无理由退货?查询方法公布
  • 训练孩子的科学思维,上海虹口推出“六个一百”旗舰工程
  • 我国成功发射遥感四十号02组卫星
  • 伊美第四轮核问题谈判开始
  • 母亲节书单|关于生育自由的未来
  • 1至4月全国铁路完成固定资产投资1947亿元,同比增长5.3%