nnUNet V2修改网络——加入GHPA 模块
更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题
阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。
EGE-UNet 是 UNet 的一个变体,专为皮肤病变分割而设计。它在 UNet 的基础上引入了 GHPA 模块和 GAB 模块,分别用于多维度特征提取和多尺度信息融合。GHPA 运用多轴分组和 Hadamard 乘积机制,能从不同视角提取病理信息,而 GAB 则通过分组聚合和卷积操作,有效整合不同层次和尺度的特征。这种创新设计使得 EGE-UNet 不仅在 ISIC2017 和 ISIC2018 数据集上取得卓越的分割效果,还大幅降低了参数和计算成本,成为低资源环境下理想的分割模型。
EGE-UNet官方代码仓库: https://github.com/JCruan519/EGE-UNet
GHPA 模块的加入,有效解决了传统自注意力机制计算量大的问题。它通过分组的方式,将复杂的多头自注意力机制简化为线性复杂度的 Hadamard 乘积操作。GHPA 将输入特征按通道维度均分为四组,分别在高度-宽度、通道-高度和通道-宽度轴上进行 Hadamard 乘积注意力操作,最后一组则通过深度可分离卷积进行处理。这种多轴分组的设计,使得模型能够从不同视角提取特征,进一步提升了分割的准确性和鲁棒性。
本文目录
- 一 准备工作
- 1. 安装dynamic-network-architectures
- 2. 生成nnUNetPlans.json文件
- 二 修改思路
- 1. GPHA模块结构
- 2. 查看 nnU-Net V2 网络结构
- 3. 替换过程
- 三 修改网络
- 1. 创建GHPA模块
- 2. 替换基础块
一 准备工作
1. 安装dynamic-network-architectures
点击链接,将其clone到本地后,进入文件夹内,pip install -e . 即可(注意-e后有个点)。
2. 生成nnUNetPlans.json文件
运行nnUNetv2_plan_and_preprocess命令,也是预处理命令,生成nnUNetPlans.json文件。
二 修改思路
由于读者预期替换的网络不一定是本文替换的网络,且所用数据集不一定是本文所用数据集,所以
本文主要介绍如何修改网络,使修改后的nnU-Net V2可以正常训练,给读者提供实践样例,不会涉及评估指标。
在对相同网络结构进行修改时,读者往往会有不同的修改思路 😃。在这一过程中,请读者以成功运行代码、是否便于读者后续回顾作为考量标准,确定修改思路。
1. GPHA模块结构
GHPA模块先将特征图通过layernorm层,再沿着通道维度将特征图分割为4份,即4份b * (c//4) * h * w大小的特征图,其中三份进行三个维度的Hadamard乘积操作,最后一份则通过卷积进行处理:
Hadamard 乘积矩阵形状 | 维度方向 |
---|---|
1 * (c//4) * h * w | 高度-宽度 |
1 * 1 * (c//4) * h | 通道-高度 |
1 * 1 * (c//4) * w | 通道-宽度 |
将四份操作后的特征图沿通道维度拼接(cat)起来,再一次通过layernorm层,卷积处理后返回特征图。
2. 查看 nnU-Net V2 网络结构
打开nnUNet \ DATASET \ nnUNet_preprocessed \ Dataset001_ACDC \ nnUNetPlans.json文件,查看configurations --> 2d --> architecture --> network_class_name字段,默认为dynamic_network_architectures.architectures.unet.PlainConvUNet。
根据network_class_name字段,找到PlainConvUNet类所在文件:dynamic-network-architectures-main \ dynamic_network_architectures \ architectures \ unet.py
PlainConvUNet类就是nnU-Net默认的U-Net,其结构由编码器和解码器两部分组成,很标准,很常见。具体代码、结构见PlainConvUNet类
3. 替换过程
我们的替换过程如下:
- 将nnU-Net V2基础块(基础块详见nnUNet V2代码——构建网络)修改为含GHPA的GHPA模块。
- nnU-Net V2的编码器第一层不变,其余层以及解码器所有层的基础块替换为GHPA模块。
三 修改网络
本次替换的一些设置
训练配置 | 2d |
更换的网络 | 加入EGE-UNet的GHPA模块 |
涉及的文件(加粗文件是要修改的文件):
dynamic_network_architectures/
|__ architectures/
| |__ resnet.py
| |__ unet.py
| |__ vgg.py
|__ building_blocks/
| |__ helper.py
| |__ plain_conv_encoder.py
| |__ regularization.py
| |__ residual.py
| |__ residual_encoders.py
| |__ simple_conv_blocks.py
| |__ unet_decoder.py
| |__ unet_residual_decoder.py
|__ initialization/
| |__ weight_init.py
本次修改为了定位修改位置,会粘贴额外的代码用于定位,替换部分会用注释标识
1. 创建GHPA模块
修改基础块(ConvDropoutNormReLU类)为GHPA模块,先改名字(用于区分原基础块,读者可自定义)为GHPAConvBlock,再改代码:
保留__init__函数的一部分,修改其余部分,保留部分用已注释标识:
# __init__函数
######################保留代码开始
self.input_channels = input_channels
self.output_channels = output_channels
stride = maybe_convert_scalar_to_list(conv_op, stride)
self.stride = stride
kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
######################保留代码结束
dim_in = input_channels
dim_out = output_channels
c_dim_in = input_channels//4
k_size = 3
padding = (k_size - 1) // 2
x=y=8
self.params_xy = nn.Parameter(torch.Tensor(1, c_dim_in, x, y), requires_grad=True)
nn.init.ones_(self.params_xy)
self.conv_xy = nn.Sequential(nn.Conv2d(c_dim_in, c_dim_in, kernel_size=k_size, padding=padding, groups=c_dim_in), nn.GELU(), nn.Conv2d(c_dim_in, c_dim_in, 1))
self.params_zx = nn.Parameter(torch.Tensor(1, 1, c_dim_in, x), requires_grad=True)
nn.init.ones_(self.params_zx)
self.conv_zx = nn.Sequential(nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=padding, groups=c_dim_in), nn.GELU(), nn.Conv1d(c_dim_in, c_dim_in, 1))
self.params_zy = nn.Parameter(torch.Tensor(1, 1, c_dim_in, y), requires_grad=True)
nn.init.ones_(self.params_zy)
self.conv_zy = nn.Sequential(nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=padding, groups=c_dim_in), nn.GELU(), nn.Conv1d(c_dim_in, c_dim_in, 1))
self.dw = nn.Sequential(
nn.Conv2d(c_dim_in, c_dim_in, 1),
nn.GELU(),
nn.Conv2d(c_dim_in, c_dim_in, kernel_size=3, padding=1, groups=c_dim_in)
)
self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
self.ldw = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
nn.GELU(),
nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=self.stride,
padding=[(i-1)//2 for i in kernel_size], bias=False)
)
forward函数全部修改:
# forward函数
x = self.norm1(x)
x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
#----------xy----------#
params_xy = self.params_xy
x1 = x1 * self.conv_xy(F.interpolate(params_xy, size=x1.shape[2:4],mode='bilinear', align_corners=True))
#----------zx----------#
x2 = x2.permute(0, 3, 1, 2)
params_zx = self.params_zx
x2 = x2 * self.conv_zx(F.interpolate(params_zx, size=x2.shape[2:4],mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
x2 = x2.permute(0, 2, 3, 1)
#----------zy----------#
x3 = x3.permute(0, 2, 1, 3)
params_zy = self.params_zy
x3 = x3 * self.conv_zy(F.interpolate(params_zy, size=x3.shape[2:4],mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
x3 = x3.permute(0, 2, 1, 3)
#----------dw----------#
x4 = self.dw(x4)
#----------concat----------#
x = torch.cat([x1,x2,x3,x4],dim=1)
#----------ldw----------#
x = self.norm2(x)
x = self.ldw(x)
return x
其余函数不变
2. 替换基础块
需要替换的类有StackedConvBlocks类、PlainConvEncoder类、UNetDecoder类。这三个类的结构见nnUNet V2代码——构建网络
为了区分原有的类,StackedConvBlocks类最好也改个名字,我改为了GHPAStackedConvBlocks。代码只需修改__init__函数的一部分,其余部分 + 其余函数不用修改:
self.convs = nn.Sequential(
########################已将下一行的ConvDropoutNormReLU类修改为GHPAConvBlock
GHPAConvBlock(
conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
),
*[
########################已将下一行的ConvDropoutNormReLU类修改为GHPAConvBlock
GHPAConvBlock(
conv_op, output_channels[i - 1], output_channels[i], kernel_size, 1, conv_bias, norm_op,
norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
)
for i in range(1, num_convs)
]
)
PlainConvEncoder类作为编码器,第一层因为通道数量的原因,不做替换,其余层均替换为GHPAStackedConvBlocks,依旧是修改__init__函数的一部分,其余部分 + 其余函数不做修改:
stages = []
for s in range(n_stages):
stage_modules = []
if pool == 'max' or pool == 'avg':
if (isinstance(strides[s], int) and strides[s] != 1) or \
isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
conv_stride = 1
elif pool == 'conv':
conv_stride = strides[s]
else:
raise RuntimeError()
###############################替换开始
if s > 0:
stage_modules.append(GHPAStackedConvBlocks(
n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
))
else:
stage_modules.append(StackedConvBlocks(
n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
))
###############################替换结束
stages.append(nn.Sequential(*stage_modules))
input_channels = features_per_stage[s]
UNetDecoder类最简单,将__init__函数的一行代码修改即可:
#######################已将下一行的StackedConvBlocks类修改为GHPAStackedConvBlocks类
stages.append(GHPAStackedConvBlocks( ##就是这一行,剩下的用于定位位置
n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
encoder.kernel_sizes[-(s + 1)], 1,
conv_bias,
norm_op,
norm_op_kwargs,
dropout_op,
dropout_op_kwargs,
nonlin,
nonlin_kwargs,
nonlin_first
))
本次修改完毕