Segmentation Network: SegFormer
Preface: I recently used the SegFormer network in a project. After preparing the data, I trained the lightweight segformer_b1 model on a dataset of roughly 4,000 annotated images and got quite good results. The final metrics were mIoU: 93.5, mPA: 95.89,
Accuracy: 98.78. The model is also small: best.pt is 52 MB unquantized and about 15 MB after quantization, and inference is fast. So I decided to write up SegFormer here.
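For reference, all three metrics can be derived from the class-wise confusion matrix. Below is a minimal sketch (my own illustration, not code from the project) of how Accuracy, mPA and mIoU are typically computed from integer label masks:

import numpy as np

def segmentation_metrics(pred, label, num_classes):
    """Illustrative sketch: Accuracy / mPA / mIoU from flat integer masks."""
    # Confusion matrix: rows = ground-truth class, cols = predicted class
    valid = (label >= 0) & (label < num_classes)
    cm = np.bincount(num_classes * label[valid] + pred[valid],
                     minlength=num_classes ** 2).reshape(num_classes, num_classes)
    acc = np.diag(cm).sum() / cm.sum()                               # overall pixel accuracy
    pa_per_class = np.diag(cm) / np.maximum(cm.sum(axis=1), 1)        # per-class pixel accuracy
    iou_per_class = np.diag(cm) / np.maximum(
        cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm), 1)             # per-class IoU
    return acc, pa_per_class.mean(), iou_per_class.mean()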
SegFormer project link: SegFormer - Hugging Face (there is also test demo code at the bottom of this post)
SegFormer paper: [2105.15203] SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
SegFormer network architecture diagram
SegFormer pipeline overview:
1. Given an image of size H×W×3, it is first divided into 4×4 patches using overlapping patch embedding.
2. Encoder: the patches are fed into a hierarchical Transformer encoder (which uses Efficient Self-Attention) to obtain multi-level features at {1/4, 1/8, 1/16, 1/32} of the original image resolution.
3. Decoder: these multi-level features are passed to an All-MLP decoder to predict the segmentation mask (a quick shape check of the four stage outputs follows this list).
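As a sanity check of the {1/4, 1/8, 1/16, 1/32} feature pyramid, the sketch below runs the four encoder stages of the SegFormer class from the complete code at the end of this post on a 512×512 input (illustrative only):

import torch

# assumes the SegFormer class from the "Complete code" section below
model = SegFormer(num_classes=3, version='b0')
x = torch.randn(1, 3, 512, 512)
for i, stage in enumerate(model.stages, start=1):
    x = stage(x)
    print(f"Stage {i}: {tuple(x.shape)}")
# Stage 1: (1, 32, 128, 128)   -> 1/4 resolution
# Stage 2: (1, 64, 64, 64)     -> 1/8
# Stage 3: (1, 160, 32, 32)    -> 1/16
# Stage 4: (1, 256, 16, 16)    -> 1/32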
Main SegFormer modules
1. Encoder
Main role: extract hierarchical multi-scale features, both coarse-grained and fine-grained.
class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(patch_size=patch_size, stride=stride,
                                             in_chans=in_channels, embed_dim=embed_dim)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(dim=embed_dim,
                             reduction_ratio=reduction_ratio,
                             num_heads=num_heads,
                             expansion_ratio=expansion_ratio)
            for _ in range(num_blocks)
        ])
        # Normalization before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)
        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)
        # Normalize
        x = self.norm(x)
        # Convert the sequence back to feature-map format [B, C, H, W]
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)  # reshape, not view: the tensor is non-contiguous after permute
        return x
1.1 Overlap Patch Embeddings
① The input image is split with a convolution: it is divided into patches of size patch_size, and the patches are shifted with stride stride so that neighbouring patches overlap.
② Each patch is then flattened into a 1-D vector and passed through a normalization layer.
Tips: 1. The module outputs a tensor of shape (B, N, C), i.e. (batch size, number of tokens, embedding dimension).
2. It also returns H and W, the spatial size of the embedded feature map, because the decoder later needs this size to restore the 2-D layout.
class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # e.g. 7x7
        # Overlapping patches: kernel_size > stride, with padding to keep alignment
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)                  # [B, C, H, W]
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, N, C] with N = H*W
        x = self.norm(x)
        return x, H, W
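A quick shape check (my own illustrative example) using the Stage-1 configuration on a 512×512 image:

import torch

embed = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=32)
img = torch.randn(1, 3, 512, 512)
tokens, H, W = embed(img)
print(tokens.shape, H, W)  # torch.Size([1, 16384, 32]) 128 128  -> N = 128*128 tokens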
1.2 Transformer Block
1.2.1 Efficient Self-Attention
① Self-attention is used, but with a sequence-reduction step that lowers the computational complexity.
② The complexity drops from O(N²) to O(N²/R): the Key/Value sequence length is reduced from N to N/R.
class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Sequence reduction: merge R neighbouring tokens and project C*R channels
        # back to C, so the Key/Value sequence shrinks from N to N/R
        if reduction_ratio > 1:
            self.reduction = nn.Sequential(
                nn.Linear(dim * reduction_ratio, dim),
                nn.LayerNorm(dim),
            )

        # Attention projections
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)  # Key and Value share the reduced sequence
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]; N must be divisible by reduction_ratio

        # 1. Reduce the Key/Value sequence length: [B, N, C] -> [B, N/R, C]
        if self.reduction_ratio > 1:
            kv_input = x.reshape(B, N // self.reduction_ratio, C * self.reduction_ratio)
            kv_input = self.reduction(kv_input)
        else:
            kv_input = x

        # 2. Build Q/K/V with the head dimension in front of the sequence dimension
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [B, heads, N, head_dim]
        kv = self.kv(kv_input).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, heads, N/R, head_dim]

        # 3. Attention (complexity O(N^2 / R))
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N/R]
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)  # back to [B, N, C]
        return self.proj(output)
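A quick shape check (my own illustrative example; here N must be divisible by R):

import torch

attn = EfficientSelfAttention(dim=64, reduction_ratio=4, num_heads=2)
tokens = torch.randn(2, 4096, 64)  # N = 64*64 tokens
print(attn(tokens).shape)          # torch.Size([2, 4096, 64])
# Internally the Key/Value sequence is reduced to N/R = 1024 tokens, so the
# attention map is [2, 2, 4096, 1024] instead of [2, 2, 4096, 4096].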
1.2.2 Mix-FFN
① Structure: a channel-expansion MLP (fully connected layer), a depth-wise convolution that injects positional information, then a channel-compression MLP.
② This replaces traditional positional encoding: the depth-wise convolution leaks positional information implicitly, avoiding the performance drop that occurs when the test resolution differs from the training resolution.
class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)
        # 1. Channel-expansion MLP
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 2. Depth-wise convolution injects positional information
        self.dwconv = nn.Conv2d(in_channels=hidden_features,
                                out_channels=hidden_features,
                                kernel_size=kernel_size,
                                padding=kernel_size // 2,
                                groups=hidden_features)  # depth-wise convolution
        # 3. Activation
        self.act = nn.GELU()
        # 4. Channel-compression MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        H, W = int(N ** 0.5), int(N ** 0.5)  # recover the 2-D layout (assumes a square feature map)

        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]

        # Switch to 2-D for the depth-wise convolution
        x = x.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)                           # positional information via convolution
        x = x.flatten(2).permute(0, 2, 1)            # back to the sequence: [B, N, hidden_C]

        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]
        return x
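A quick shape check (illustrative; the token count must come from a square feature map):

import torch

ffn = MixFFN(in_features=64, expansion_ratio=4)
tokens = torch.randn(2, 64 * 64, 64)  # [B, N, C] with N = 64*64
print(ffn(tokens).shape)              # torch.Size([2, 4096, 64]) -- shape is preserved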
2. Decoder
Main role: a lightweight All-MLP decoder that directly fuses the multi-level features and predicts the semantic segmentation mask.
class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels
        # 1. Channel-alignment MLP (one per encoder stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list
        ])
        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(in_channels=4 * unified_channels,
                                           out_channels=unified_channels)
        # 3. Segmentation head (pixel-level classification)
        self.seg_head = SegmentationHead(in_channels=unified_channels,
                                         num_classes=num_classes)

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned_features.append(self.align_mlps[i](feat))

        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []
        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(feat, size=target_size,
                                    mode='bilinear', align_corners=False)
            upsampled_features.append(up_feat)

        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]

        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]

        # Step 5: semantic prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]
        return seg_mask
2.1 MLP Layer
① The hierarchical multi-scale features from the encoder are first projected to a common channel width, upsampled to a common resolution, and then fused, combining semantic information from different resolutions.
2.2 MLP
① A final MLP (implemented as a 1×1 convolution) produces the pixel-level classification result.
class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer but works on 2-D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The input has 4*C channels (four concatenated feature maps)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)
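The pixel-level classifier mentioned above is itself just a 1×1 convolution; for completeness, here is the SegmentationHead used by the decoder (taken from the complete code below), followed by a small illustrative shape check of the whole decoder:

class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution performs the pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


# Illustrative shape check (assumes the SegFormerDecoder defined above and the
# imports `torch` / `torch.nn.functional as F` from the complete code):
decoder = SegFormerDecoder(in_channels_list=[32, 64, 160, 256],
                           unified_channels=256, num_classes=3)
feats = [torch.randn(2, 32, 128, 128),   # 1/4 resolution
         torch.randn(2, 64, 64, 64),     # 1/8
         torch.randn(2, 160, 32, 32),    # 1/16
         torch.randn(2, 256, 16, 16)]    # 1/32
print(decoder(feats).shape)  # torch.Size([2, 3, 128, 128]) -- still 1/4 resolution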
Complete code:
import torch
import torch.nn as nn
import torch.nn.functional as F


class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # e.g. 7x7
        # Overlapping patches: kernel_size > stride, with padding to keep alignment
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)                  # [B, C, H, W]
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, N, C] with N = H*W
        x = self.norm(x)
        return x, H, W


class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Sequence reduction: merge R neighbouring tokens and project C*R channels
        # back to C, so the Key/Value sequence shrinks from N to N/R
        if reduction_ratio > 1:
            self.reduction = nn.Sequential(
                nn.Linear(dim * reduction_ratio, dim),
                nn.LayerNorm(dim),
            )

        # Attention projections
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)  # Key and Value share the reduced sequence
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]; N must be divisible by reduction_ratio

        # 1. Reduce the Key/Value sequence length: [B, N, C] -> [B, N/R, C]
        if self.reduction_ratio > 1:
            kv_input = x.reshape(B, N // self.reduction_ratio, C * self.reduction_ratio)
            kv_input = self.reduction(kv_input)
        else:
            kv_input = x

        # 2. Build Q/K/V with the head dimension in front of the sequence dimension
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [B, heads, N, head_dim]
        kv = self.kv(kv_input).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, heads, N/R, head_dim]

        # 3. Attention (complexity O(N^2 / R))
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N/R]
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)  # back to [B, N, C]
        return self.proj(output)


class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)
        # 1. Channel-expansion MLP
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 2. Depth-wise convolution injects positional information
        self.dwconv = nn.Conv2d(in_channels=hidden_features,
                                out_channels=hidden_features,
                                kernel_size=kernel_size,
                                padding=kernel_size // 2,
                                groups=hidden_features)  # depth-wise convolution
        # 3. Activation
        self.act = nn.GELU()
        # 4. Channel-compression MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        H, W = int(N ** 0.5), int(N ** 0.5)  # recover the 2-D layout (assumes a square feature map)

        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]

        # Switch to 2-D for the depth-wise convolution
        x = x.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)                           # positional information via convolution
        x = x.flatten(2).permute(0, 2, 1)            # back to the sequence: [B, N, hidden_C]

        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads, expansion_ratio=4):
        super().__init__()
        # Normalization layers
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Attention and FFN
        self.attn = EfficientSelfAttention(dim, reduction_ratio, num_heads)
        self.mixffn = MixFFN(dim, expansion_ratio)

    def forward(self, x):
        # Residual connection 1: Efficient Self-Attention
        x = x + self.attn(self.norm1(x))
        # Residual connection 2: Mix-FFN
        x = x + self.mixffn(self.norm2(x))
        return x


class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(patch_size=patch_size, stride=stride,
                                             in_chans=in_channels, embed_dim=embed_dim)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(dim=embed_dim,
                             reduction_ratio=reduction_ratio,
                             num_heads=num_heads,
                             expansion_ratio=expansion_ratio)
            for _ in range(num_blocks)
        ])
        # Normalization before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)
        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)
        # Normalize
        x = self.norm(x)
        # Convert the sequence back to feature-map format [B, C, H, W]
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)  # reshape, not view: non-contiguous after permute
        return x


class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer but works on 2-D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The input has 4*C channels (four concatenated feature maps)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fc(x)


class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution performs the pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels
        # 1. Channel-alignment MLP (one per encoder stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list
        ])
        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(in_channels=4 * unified_channels,
                                           out_channels=unified_channels)
        # 3. Segmentation head (pixel-level classification)
        self.seg_head = SegmentationHead(in_channels=unified_channels,
                                         num_classes=num_classes)

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned_features.append(self.align_mlps[i](feat))

        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []
        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(feat, size=target_size,
                                    mode='bilinear', align_corners=False)
            upsampled_features.append(up_feat)

        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]

        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]

        # Step 5: semantic prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]
        return seg_mask


class SegFormer(nn.Module):
    def __init__(self, num_classes=3, version='b0'):
        super().__init__()
        # Pick the configuration for the requested variant
        if version == 'b0':
            config = {
                'stages': [
                    # [in_channels, embed_dim, num_blocks, reduction_ratio, num_heads, expansion_ratio, patch_size, stride]
                    [3, 32, 2, 8, 1, 8, 7, 4],     # Stage 1
                    [32, 64, 2, 4, 2, 8, 3, 2],    # Stage 2
                    [64, 160, 2, 2, 5, 4, 3, 2],   # Stage 3
                    [160, 256, 2, 1, 8, 4, 3, 2],  # Stage 4
                ],
                'decoder_channels': 256
            }
        elif version == 'b1':
            config = {
                'stages': [
                    [3, 64, 2, 8, 1, 8, 7, 4],
                    [64, 128, 2, 4, 2, 8, 3, 2],
                    [128, 320, 2, 2, 5, 4, 3, 2],
                    [320, 512, 2, 1, 8, 4, 3, 2],
                ],
                'decoder_channels': 256
            }
        else:
            raise ValueError(f"Unsupported version: {version}")

        # Build the encoder stages
        self.stages = nn.ModuleList()
        in_channels_list = []  # channel widths fed to the decoder
        for stage_config in config['stages']:
            (in_channels, embed_dim, num_blocks, reduction_ratio,
             num_heads, expansion_ratio, patch_size, stride) = stage_config
            self.stages.append(SegFormerStage(
                in_channels=in_channels,
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                reduction_ratio=reduction_ratio,
                num_heads=num_heads,
                expansion_ratio=expansion_ratio,
                patch_size=patch_size,
                stride=stride))
            in_channels_list.append(embed_dim)

        # Build the decoder
        self.decoder = SegFormerDecoder(
            in_channels_list=in_channels_list,
            unified_channels=config['decoder_channels'],
            num_classes=num_classes)

    def forward(self, x):
        # Collect the output of every encoder stage
        stage_outputs = []
        for stage in self.stages:
            x = stage(x)  # each stage consumes the previous stage's feature map
            stage_outputs.append(x)

        # Decode the multi-level features into a mask at 1/4 resolution
        seg_mask = self.decoder(stage_outputs)

        # Upsample back to the original resolution
        seg_mask = F.interpolate(seg_mask, scale_factor=4, mode='bilinear', align_corners=False)
        return seg_mask


# Quick test of the model
if __name__ == "__main__":
    # Build the model
    model = SegFormer(num_classes=3, version='b0')
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # Dummy input
    input_tensor = torch.randn(2, 3, 512, 512)  # [batch, channels, height, width]

    # Forward pass
    output = model(input_tensor)
    print(f"Input shape:  {input_tensor.shape}")
    print(f"Output shape: {output.shape}")  # expected: [2, 3, 512, 512]

    # Simple sanity check on the output range
    print(f"Output min: {output.min().item():.4f}, max: {output.max().item():.4f}")

    # Optional: save a diagram of the model structure
    try:
        from torchviz import make_dot
        dot = make_dot(output, params=dict(model.named_parameters()))
        dot.render("segformer_model", format="png")
        print("Model diagram saved as segformer_model.png")
    except ImportError:
        print("torchviz not installed, skipping the model diagram")
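The preface mentions that quantization shrank the checkpoint from 52 MB to roughly 15 MB. One simple way to get that kind of size reduction is post-training dynamic quantization of the linear layers; the sketch below is only an illustration of that idea (using the SegFormer class defined above), not necessarily the method used in the project:

import torch
import torch.nn as nn

model = SegFormer(num_classes=3, version='b1')
model.eval()

# Dynamically quantize all nn.Linear layers to int8 (weights stored as int8,
# activations quantized on the fly); conv layers are left in float32 here.
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

torch.save(model.state_dict(), "segformer_fp32.pt")
torch.save(quantized.state_dict(), "segformer_int8.pt")  # noticeably smaller on disk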
Finally, here is DeepSeek's take on the metrics from this training run, which made me laugh.