当前位置: 首页 > news >正文

nnUNet V2修改网络——加入GHPA 模块

更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题

阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。

EGE-UNet 是 UNet 的一个变体,专为皮肤病变分割而设计。它在 UNet 的基础上引入了 GHPA 模块和 GAB 模块,分别用于多维度特征提取和多尺度信息融合。GHPA 运用多轴分组和 Hadamard 乘积机制,能从不同视角提取病理信息,而 GAB 则通过分组聚合和卷积操作,有效整合不同层次和尺度的特征。这种创新设计使得 EGE-UNet 不仅在 ISIC2017 和 ISIC2018 数据集上取得卓越的分割效果,还大幅降低了参数和计算成本,成为低资源环境下理想的分割模型。

EGE-UNet官方代码仓库: https://github.com/JCruan519/EGE-UNet

GHPA 模块的加入,有效解决了传统自注意力机制计算量大的问题。它通过分组的方式,将复杂的多头自注意力机制简化为线性复杂度的 Hadamard 乘积操作。GHPA 将输入特征按通道维度均分为四组,分别在高度-宽度、通道-高度和通道-宽度轴上进行 Hadamard 乘积注意力操作,最后一组则通过深度可分离卷积进行处理。这种多轴分组的设计,使得模型能够从不同视角提取特征,进一步提升了分割的准确性和鲁棒性。
请添加图片描述

本文目录

  • 一 准备工作
    • 1. 安装dynamic-network-architectures
    • 2. 生成nnUNetPlans.json文件
  • 二 修改思路
    • 1. GPHA模块结构
    • 2. 查看 nnU-Net V2 网络结构
    • 3. 替换过程
  • 三 修改网络
    • 1. 创建GHPA模块
    • 2. 替换基础块

一 准备工作

1. 安装dynamic-network-architectures

点击链接,将其clone到本地后,进入文件夹内,pip install -e . 即可(注意-e后有个点)。

2. 生成nnUNetPlans.json文件

运行nnUNetv2_plan_and_preprocess命令,也是预处理命令,生成nnUNetPlans.json文件

二 修改思路

由于读者预期替换的网络不一定是本文替换的网络,且所用数据集不一定是本文所用数据集,所以
本文主要介绍如何修改网络,使修改后的nnU-Net V2可以正常训练,给读者提供实践样例,不会涉及评估指标。
在对相同网络结构进行修改时,读者往往会有不同的修改思路 😃。在这一过程中,请读者以成功运行代码、是否便于读者后续回顾作为考量标准,确定修改思路。

1. GPHA模块结构

GHPA模块先将特征图通过layernorm层,再沿着通道维度将特征图分割为4份,即4份b * (c//4) * h * w大小的特征图,其中三份进行三个维度的Hadamard乘积操作,最后一份则通过卷积进行处理:

Hadamard 乘积矩阵形状维度方向
1 * (c//4) * h * w高度-宽度
1 * 1 * (c//4) * h通道-高度
1 * 1 * (c//4) * w通道-宽度

将四份操作后的特征图沿通道维度拼接(cat)起来,再一次通过layernorm层,卷积处理后返回特征图。

2. 查看 nnU-Net V2 网络结构

打开nnUNet \ DATASET \ nnUNet_preprocessed \ Dataset001_ACDC \ nnUNetPlans.json文件,查看configurations --> 2d --> architecture --> network_class_name字段,默认为dynamic_network_architectures.architectures.unet.PlainConvUNet

根据network_class_name字段,找到PlainConvUNet类所在文件:dynamic-network-architectures-main \ dynamic_network_architectures \ architectures \ unet.py

PlainConvUNet类就是nnU-Net默认的U-Net,其结构由编码器和解码器两部分组成,很标准,很常见。具体代码、结构见PlainConvUNet类

3. 替换过程

我们的替换过程如下:

  1. 将nnU-Net V2基础块(基础块详见nnUNet V2代码——构建网络)修改为含GHPA的GHPA模块。
  2. nnU-Net V2的编码器第一层不变,其余层以及解码器所有层的基础块替换为GHPA模块。

三 修改网络

本次替换的一些设置

训练配置2d
更换的网络加入EGE-UNet的GHPA模块

涉及的文件(加粗文件是要修改的文件):

dynamic_network_architectures/
|__ architectures/
| |__ resnet.py
| |__ unet.py
| |__ vgg.py
|__ building_blocks/
| |__ helper.py
| |__ plain_conv_encoder.py
| |__ regularization.py
| |__ residual.py
| |__ residual_encoders.py
| |__ simple_conv_blocks.py
| |__ unet_decoder.py
| |__ unet_residual_decoder.py
|__ initialization/
| |__ weight_init.py

本次修改为了定位修改位置,会粘贴额外的代码用于定位,替换部分会用注释标识

1. 创建GHPA模块

修改基础块(ConvDropoutNormReLU类)为GHPA模块,先改名字(用于区分原基础块,读者可自定义)为GHPAConvBlock,再改代码:

保留__init__函数的一部分,修改其余部分,保留部分用已注释标识:

# __init__函数
######################保留代码开始
self.input_channels = input_channels
self.output_channels = output_channels
stride = maybe_convert_scalar_to_list(conv_op, stride)
self.stride = stride

kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
######################保留代码结束
dim_in = input_channels
dim_out = output_channels
c_dim_in = input_channels//4
k_size = 3
padding = (k_size - 1) // 2
x=y=8
self.params_xy = nn.Parameter(torch.Tensor(1, c_dim_in, x, y), requires_grad=True)
nn.init.ones_(self.params_xy)
self.conv_xy = nn.Sequential(nn.Conv2d(c_dim_in, c_dim_in, kernel_size=k_size, padding=padding, groups=c_dim_in), nn.GELU(), nn.Conv2d(c_dim_in, c_dim_in, 1))

self.params_zx = nn.Parameter(torch.Tensor(1, 1, c_dim_in, x), requires_grad=True)
nn.init.ones_(self.params_zx)
self.conv_zx = nn.Sequential(nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=padding, groups=c_dim_in), nn.GELU(), nn.Conv1d(c_dim_in, c_dim_in, 1))

self.params_zy = nn.Parameter(torch.Tensor(1, 1, c_dim_in, y), requires_grad=True)
nn.init.ones_(self.params_zy)
self.conv_zy = nn.Sequential(nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=padding, groups=c_dim_in), nn.GELU(), nn.Conv1d(c_dim_in, c_dim_in, 1))

self.dw = nn.Sequential(
        nn.Conv2d(c_dim_in, c_dim_in, 1),
        nn.GELU(),
        nn.Conv2d(c_dim_in, c_dim_in, kernel_size=3, padding=1, groups=c_dim_in)
)

self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')

self.ldw = nn.Sequential(
        nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
        nn.GELU(),
        nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=self.stride, 
                    padding=[(i-1)//2 for i in kernel_size], bias=False)
)

forward函数全部修改:

# forward函数
x = self.norm1(x)
x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
#----------xy----------#
params_xy = self.params_xy
x1 = x1 * self.conv_xy(F.interpolate(params_xy, size=x1.shape[2:4],mode='bilinear', align_corners=True))
#----------zx----------#
x2 = x2.permute(0, 3, 1, 2)
params_zx = self.params_zx
x2 = x2 * self.conv_zx(F.interpolate(params_zx, size=x2.shape[2:4],mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
x2 = x2.permute(0, 2, 3, 1)
#----------zy----------#
x3 = x3.permute(0, 2, 1, 3)
params_zy = self.params_zy
x3 = x3 * self.conv_zy(F.interpolate(params_zy, size=x3.shape[2:4],mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
x3 = x3.permute(0, 2, 1, 3)
#----------dw----------#
x4 = self.dw(x4)
#----------concat----------#
x = torch.cat([x1,x2,x3,x4],dim=1)
#----------ldw----------#
x = self.norm2(x)
x = self.ldw(x)
return x

其余函数不变

2. 替换基础块

需要替换的类有StackedConvBlocks类、PlainConvEncoder类、UNetDecoder类。这三个类的结构见nnUNet V2代码——构建网络

为了区分原有的类,StackedConvBlocks类最好也改个名字,我改为了GHPAStackedConvBlocks。代码只需修改__init__函数的一部分,其余部分 + 其余函数不用修改:

self.convs = nn.Sequential(
	########################已将下一行的ConvDropoutNormReLU类修改为GHPAConvBlock
    GHPAConvBlock(
        conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
        norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
    ),
    *[
    ########################已将下一行的ConvDropoutNormReLU类修改为GHPAConvBlock
        GHPAConvBlock(
            conv_op, output_channels[i - 1], output_channels[i], kernel_size, 1, conv_bias, norm_op,
            norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
        )
        for i in range(1, num_convs)
    ]
)

PlainConvEncoder类作为编码器,第一层因为通道数量的原因,不做替换,其余层均替换为GHPAStackedConvBlocks,依旧是修改__init__函数的一部分,其余部分 + 其余函数不做修改:

stages = []
for s in range(n_stages):
    stage_modules = []
    if pool == 'max' or pool == 'avg':
        if (isinstance(strides[s], int) and strides[s] != 1) or \
                isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
            stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
        conv_stride = 1
    elif pool == 'conv':
        conv_stride = strides[s]
    else:
        raise RuntimeError()
    ###############################替换开始
    if s > 0:
        stage_modules.append(GHPAStackedConvBlocks(
            n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
            conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
        ))
    else:
        stage_modules.append(StackedConvBlocks(
            n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
            conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
        ))
    ###############################替换结束
    stages.append(nn.Sequential(*stage_modules))
    input_channels = features_per_stage[s]

UNetDecoder类最简单,将__init__函数的一行代码修改即可:

#######################已将下一行的StackedConvBlocks类修改为GHPAStackedConvBlocks类
stages.append(GHPAStackedConvBlocks(	##就是这一行,剩下的用于定位位置
    n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
    encoder.kernel_sizes[-(s + 1)], 1,
    conv_bias,
    norm_op,
    norm_op_kwargs,
    dropout_op,
    dropout_op_kwargs,
    nonlin,
    nonlin_kwargs,
    nonlin_first
))

本次修改完毕

相关文章:

  • 【Qt】 Data Visualization
  • P8752 [蓝桥杯 2021 省 B2] 特殊年份——string提取索引转换为值
  • ARM系统源码编译OpenCV 4.10.0(包含opencv_contrib)
  • vue3和vue2的组件开发有什么区别
  • 3.10 企业级AI内容生成引擎:从策略到落地的全链路技术指南
  • 【大模型】Transformers基础组件 - Tokenizer
  • 2024年职高单招或高考计算机类投档线
  • Python基于Django的人脸识别上课考勤管理系统【附源码】
  • flink jobgraph详细介绍
  • Golang GORM系列:GORM并发与连接池
  • 未来游戏:当人工智能重构虚拟世界的底层逻辑
  • 【mysql】数据类型介绍-空间类型-空间索引
  • Docker换源加速(更换镜像源)详细教程(2025.2最新可用镜像,全网最详细)
  • 机械学习基础-10.从时间序列数据中学习-数据建模与机械智能课程自留
  • LabVIEW的吞雨测控系统
  • 探讨如何加快 C# 多层循环的速度效率
  • 软件测试:定义和实质
  • 观望=没有!
  • 利用websocket检测网络连接稳定性
  • MySQL 清空表的数据
  • 游戏界面设计网站/权重查询
  • 京东网站的建设情况/百度关键词搜索排行榜
  • 网站建设技术规范及要求/小红书外链管家
  • 上海网站制作开发公司/长尾关键词快速排名软件
  • 网站制作职业/鸡西网站seo
  • 北京手机网站建设/网站建设推广专家服务