
Segmentation Network: SegFormer

Preface: I recently worked on a project that uses the SegFormer network. After preparing a dataset of about 4,000 labeled images, the lightweight segformer_b1 model already gave good results, with final metrics of mIoU 93.5, mPA 95.89, and Accuracy 98.78. The model is also small (best.pt is 52 MB unquantized, 15 MB after quantization) and inference is fast. So I decided to write up SegFormer.
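The post does not say how the quantization was done. Purely as a reference point, below is a minimal sketch of PyTorch post-training dynamic quantization, one common way to shrink a Linear-heavy model roughly 4x; the checkpoint path and format here are assumptions, not what was necessarily used.

# Hedged sketch: post-training dynamic quantization of the Linear layers.
# Assumes best.pt stores the full model object; adjust if it stores a state_dict.
import torch
model = torch.load('best.pt', map_location='cpu')
model.eval()
quantized = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
torch.save(quantized.state_dict(), 'best_int8.pt')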

SegFormer project link: SegFormer - Hugging Face (demo/test code is also provided at the end of this post)

SegFormer paper: [2105.15203] SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

SegFormer network architecture diagram

A brief overview of the SegFormer pipeline:

1. Given an input image of size H×W×3, it is first divided into overlapping patches of size 4×4.

2. Encoder: these patches are fed into a hierarchical Transformer encoder (which introduces Efficient Self-Attention) to obtain multi-level features at {1/4, 1/8, 1/16, 1/32} of the original image resolution.

3. Decoder: the multi-level features are passed to an MLP decoder to predict the segmentation mask. (A shape trace is sketched below.)
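To make the shapes concrete, here is a small trace (a sketch, assuming the B0 channel widths used in the code later in this post) of the feature maps for a 512×512 input:

# Stage resolutions and widths for a 512x512 input (B0 widths assumed)
strides = [4, 8, 16, 32]
channels = [32, 64, 160, 256]  # embed_dim of each B0 stage
for s, c in zip(strides, channels):
    print(f"1/{s} resolution: [B, {c}, {512 // s}, {512 // s}]")
# The decoder fuses all four into [B, num_classes, 128, 128], then upsamples x4 back to 512x512.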

Main SegFormer modules

1. Encoder

Main role: extracts hierarchical multi-scale features, both coarse grained and fine grained.

class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(patch_size=patch_size, stride=stride,
                                             in_chans=in_channels, embed_dim=embed_dim)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(dim=embed_dim, reduction_ratio=reduction_ratio,
                             num_heads=num_heads, expansion_ratio=expansion_ratio)
            for _ in range(num_blocks)])
        # Norm applied before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)
        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Convert the sequence back to a feature map [B, C, H, W]
        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)
        return x
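As a quick usage sketch (assuming torch is imported and TransformerBlock from the full listing at the end of this post is in scope), stage 1 of the B0 configuration maps a 512×512 image to 1/4-resolution features:

import torch

# Stage 1 of B0: overlapping 7x7 patches with stride 4, two Transformer blocks
stage1 = SegFormerStage(in_channels=3, embed_dim=32, num_blocks=2, reduction_ratio=8,
                        num_heads=1, expansion_ratio=8, patch_size=7, stride=4)
out = stage1(torch.randn(1, 3, 512, 512))
print(out.shape)  # torch.Size([1, 32, 128, 128]), i.e. 1/4 of the input resolution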
1.1 Overlap Patch Embeddings

① The input image is split with a convolution: the convolution divides the image into patches of size patch_size and moves with stride stride, so that adjacent patches overlap.

② Each patch is then flattened into a 1-D vector and passed through a normalization layer.

Tips: 1. The module outputs a tensor of shape (B, N, C), i.e. (batch size, number of patches, embedding dimension).

2. It also returns H and W, the spatial size of the embedded feature map, which is needed later to reshape the token sequence back into a 2-D feature map.

class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # e.g. 7x7
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, N, C]
        x = self.norm(x)
        return x, H, W
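A minimal usage sketch (the sizes here are illustrative): a 224×224 image with stride 4 yields a 56×56 patch grid, so N = 3136 tokens:

import torch

embed = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=64)
tokens, H, W = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape, H, W)  # torch.Size([1, 3136, 64]) 56 56, since 224/4 = 56 and N = 56*56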
1.2 Transformer Block
1.2.1 Efficient Self-Attention

① Self-attention is applied, with a sequence-reduction step that lowers the computational complexity.

② The complexity drops from O(N²) to O(N²/R), because the Key/Value sequence length is reduced from N to N/R. For example, at stage 1 of a 512×512 input, N = 128 × 128 = 16384, and with R = 8 the reduced sequence has 2048 tokens.

class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Sequence reduction as in the paper: reshape [B, N, C] -> [B, N/R, C*R],
        # then project the C*R channels back to C
        self.reduction = nn.Sequential(
            nn.Linear(dim * reduction_ratio, dim),
            nn.LayerNorm(dim),
        )
        # Attention projections
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)  # Key and Value share the reduced sequence

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]
        R = self.reduction_ratio
        # 1. Shorten the Key/Value sequence: [B, N, C] -> [B, N/R, C*R] -> [B, N/R, C]
        reduced = self.reduction(x.reshape(B, N // R, C * R))
        # 2. Produce Q/K/V and split into heads: [B, num_heads, seq, head_dim]
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(reduced).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, num_heads, N/R, head_dim]
        # 3. Attention with complexity O(N^2 / R)
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return output
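A quick sanity check (sketch, assuming torch is imported) with a stage-1-sized input: N = 128 × 128 = 16384 tokens and R = 8, so the Key/Value sequence is internally shortened to 2048:

import torch

attn = EfficientSelfAttention(dim=32, reduction_ratio=8, num_heads=1)
x = torch.randn(2, 128 * 128, 32)
print(attn(x).shape)  # torch.Size([2, 16384, 32]); the attention matrix is 16384 x 2048, not 16384 x 16384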
1.2.2 Mix-FFN

① A channel-expansion MLP (linear layer), a depth-wise convolution that injects positional information, and a channel-compression MLP.

② This replaces traditional positional encoding: the zero padding of the depth-wise convolution leaks location information, which avoids the performance drop when the test resolution differs from the training resolution.

class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)
        # 1. Channel-expansion MLP
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 2. Depth-wise convolution that injects positional information
        self.dwconv = nn.Conv2d(in_channels=hidden_features, out_channels=hidden_features,
                                kernel_size=kernel_size, padding=kernel_size // 2,
                                groups=hidden_features)  # depth-wise convolution
        # 3. Activation
        self.act = nn.GELU()
        # 4. Channel-compression MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        H, W = int(N ** 0.5), int(N ** 0.5)  # recover the 2-D grid (assumes a square feature map)
        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]
        # Back to 2-D for the convolution
        x = x.permute(0, 2, 1).view(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)  # the depth-wise conv leaks location information
        x = x.flatten(2).permute(0, 2, 1)  # back to a sequence [B, N, hidden_C]
        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]
        return x
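A small usage sketch (assuming torch is imported); note that this simplified version assumes a square token grid, since it recovers H and W from N:

import torch

ffn = MixFFN(in_features=64, expansion_ratio=4)
x = torch.randn(1, 128 * 128, 64)
print(ffn(x).shape)  # torch.Size([1, 16384, 64]); the shape is preserved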
2. Decoder

Main role: a lightweight all-MLP decoder that directly fuses the multi-level features and predicts the semantic segmentation mask.

class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels
        # 1. Channel-alignment MLP (one per stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list])
        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(
            in_channels=4 * unified_channels,
            out_channels=unified_channels)
        # 3. Segmentation-prediction MLP
        self.seg_head = SegmentationHead(
            in_channels=unified_channels,
            num_classes=num_classes)

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned_features.append(self.align_mlps[i](feat))
        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []
        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(feat, size=target_size,
                                    mode='bilinear', align_corners=False)
            upsampled_features.append(up_feat)
        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]
        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]
        # Step 5: segmentation prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]
        return seg_mask
2.1 MLP Layer

① The hierarchical multi-scale features are upsampled to a unified resolution, and the semantic information from the different resolutions is then fused (see the sketch below).
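For example, bringing stage-4 features (1/32 resolution) up to the stage-1 grid (1/4 resolution) is a single bilinear interpolation, which is what the decoder above does for every stage; the sizes here assume a 512×512 input and B0 widths:

import torch
import torch.nn.functional as F

feat4 = torch.randn(1, 256, 16, 16)  # stage-4 output at 1/32 resolution
up = F.interpolate(feat4, size=(128, 128), mode='bilinear', align_corners=False)
print(up.shape)  # torch.Size([1, 256, 128, 128])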

2.2 MLP

① A final MLP produces the pixel-level classification result.

class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer but works on 2-D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Input channels are 4*C (four feature maps concatenated)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.fc(x)
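Putting the decoder together, here is a sketch (assuming torch is imported and SegmentationHead from the full listing at the end of this post is in scope) that decodes four fake B0 stage outputs for a 512×512 input:

import torch

feats = [torch.randn(1, c, s, s) for c, s in zip([32, 64, 160, 256], [128, 64, 32, 16])]
decoder = SegFormerDecoder(in_channels_list=[32, 64, 160, 256],
                           unified_channels=256, num_classes=19)
print(decoder(feats).shape)  # torch.Size([1, 19, 128, 128]); the model then upsamples x4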

Full code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = (patch_size, patch_size)  # e.g. 7x7
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, N, C]
        x = self.norm(x)
        return x, H, W


class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Sequence reduction: reshape [B, N, C] -> [B, N/R, C*R], then project C*R back to C
        self.reduction = nn.Sequential(
            nn.Linear(dim * reduction_ratio, dim),
            nn.LayerNorm(dim),
        )
        # Attention projections
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)  # Key and Value share the reduced sequence

    def forward(self, x):
        B, N, C = x.shape  # [batch, seq_len, channels]
        R = self.reduction_ratio
        # 1. Shorten the Key/Value sequence: [B, N, C] -> [B, N/R, C*R] -> [B, N/R, C]
        reduced = self.reduction(x.reshape(B, N // R, C * R))
        # 2. Produce Q/K/V and split into heads: [B, num_heads, seq, head_dim]
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv(reduced).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)  # each [B, num_heads, N/R, head_dim]
        # 3. Attention with complexity O(N^2 / R)
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        output = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return output


class MixFFN(nn.Module):
    def __init__(self, in_features, expansion_ratio=4, kernel_size=3):
        super().__init__()
        hidden_features = int(in_features * expansion_ratio)
        # 1. Channel-expansion MLP
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 2. Depth-wise convolution that injects positional information
        self.dwconv = nn.Conv2d(in_channels=hidden_features, out_channels=hidden_features,
                                kernel_size=kernel_size, padding=kernel_size // 2,
                                groups=hidden_features)  # depth-wise convolution
        # 3. Activation
        self.act = nn.GELU()
        # 4. Channel-compression MLP
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        # Input shape: [batch, seq_len, channels]
        B, N, C = x.shape
        H, W = int(N ** 0.5), int(N ** 0.5)  # recover the 2-D grid (assumes a square feature map)
        # Channel expansion
        x = self.fc1(x)  # [B, N, hidden_C]
        # Back to 2-D for the convolution
        x = x.permute(0, 2, 1).view(B, -1, H, W)  # [B, hidden_C, H, W]
        x = self.dwconv(x)  # the depth-wise conv leaks location information
        x = x.flatten(2).permute(0, 2, 1)  # back to a sequence [B, N, hidden_C]
        # Activation and compression
        x = self.act(x)
        x = self.fc2(x)  # [B, N, C]
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, reduction_ratio, num_heads, expansion_ratio=4):
        super().__init__()
        # Pre-norm layers
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Attention and FFN
        self.attn = EfficientSelfAttention(dim, reduction_ratio, num_heads)
        self.mixffn = MixFFN(dim, expansion_ratio)

    def forward(self, x):
        # Residual connection 1: efficient self-attention
        x = x + self.attn(self.norm1(x))
        # Residual connection 2: Mix-FFN
        x = x + self.mixffn(self.norm2(x))
        return x


class SegFormerStage(nn.Module):
    def __init__(self, in_channels, embed_dim, num_blocks, reduction_ratio,
                 num_heads, expansion_ratio, patch_size, stride):
        super().__init__()
        # Overlapping patch embedding
        self.patch_embed = OverlapPatchEmbed(patch_size=patch_size, stride=stride,
                                             in_chans=in_channels, embed_dim=embed_dim)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(dim=embed_dim, reduction_ratio=reduction_ratio,
                             num_heads=num_heads, expansion_ratio=expansion_ratio)
            for _ in range(num_blocks)])
        # Norm applied before converting the sequence back to a feature map
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Patch embedding
        x, H, W = self.patch_embed(x)
        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Convert the sequence back to a feature map [B, C, H, W]
        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)
        return x


class ChannelAlignMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 convolution is equivalent to a linear layer but works on 2-D feature maps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class FeatureFusionMLP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Input channels are 4*C (four feature maps concatenated)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.fc(x)


class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 convolution for pixel-level classification
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class SegFormerDecoder(nn.Module):
    def __init__(self, in_channels_list, unified_channels=256, num_classes=19):
        super().__init__()
        self.unified_channels = unified_channels
        # 1. Channel-alignment MLP (one per stage)
        self.align_mlps = nn.ModuleList([
            ChannelAlignMLP(in_ch, unified_channels)
            for in_ch in in_channels_list])
        # 2. Feature-fusion MLP
        self.fusion_mlp = FeatureFusionMLP(
            in_channels=4 * unified_channels,
            out_channels=unified_channels)
        # 3. Segmentation-prediction MLP
        self.seg_head = SegmentationHead(
            in_channels=unified_channels,
            num_classes=num_classes)

    def forward(self, features):
        # Step 1: channel alignment
        aligned_features = []
        for i, feat in enumerate(features):
            aligned_features.append(self.align_mlps[i](feat))
        # Step 2: upsample everything to 1/4 resolution
        target_size = aligned_features[0].shape[2:]  # (H/4, W/4)
        upsampled_features = []
        for feat in aligned_features:
            # Bilinear upsampling
            up_feat = F.interpolate(feat, size=target_size,
                                    mode='bilinear', align_corners=False)
            upsampled_features.append(up_feat)
        # Step 3: concatenate along the channel dimension
        fused = torch.cat(upsampled_features, dim=1)  # [B, 4*C, H/4, W/4]
        # Step 4: feature fusion
        fused = self.fusion_mlp(fused)  # [B, C, H/4, W/4]
        # Step 5: segmentation prediction
        seg_mask = self.seg_head(fused)  # [B, num_classes, H/4, W/4]
        return seg_mask


class SegFormer(nn.Module):
    def __init__(self, num_classes=3, version='b0'):
        super().__init__()
        # Pick the configuration for the requested variant
        if version == 'b0':
            config = {
                'stages': [
                    # [in_channels, embed_dim, num_blocks, reduction_ratio, num_heads, expansion_ratio, patch_size, stride]
                    [3, 32, 2, 8, 1, 8, 7, 4],     # Stage 1
                    [32, 64, 2, 4, 2, 8, 3, 2],    # Stage 2
                    [64, 160, 2, 2, 5, 4, 3, 2],   # Stage 3
                    [160, 256, 2, 1, 8, 4, 3, 2],  # Stage 4
                ],
                'decoder_channels': 256,
            }
        elif version == 'b1':
            config = {
                'stages': [
                    [3, 64, 2, 8, 1, 8, 7, 4],
                    [64, 128, 2, 4, 2, 8, 3, 2],
                    [128, 320, 2, 2, 5, 4, 3, 2],
                    [320, 512, 2, 1, 8, 4, 3, 2],
                ],
                'decoder_channels': 256,
            }
        else:
            raise ValueError(f"Unsupported version: {version}")

        # Build the encoder stages
        self.stages = nn.ModuleList()
        in_channels_list = []  # input channels for the decoder
        for stage_config in config['stages']:
            (in_channels, embed_dim, num_blocks, reduction_ratio,
             num_heads, expansion_ratio, patch_size, stride) = stage_config
            self.stages.append(SegFormerStage(
                in_channels=in_channels, embed_dim=embed_dim, num_blocks=num_blocks,
                reduction_ratio=reduction_ratio, num_heads=num_heads,
                expansion_ratio=expansion_ratio, patch_size=patch_size, stride=stride))
            in_channels_list.append(embed_dim)

        # Build the decoder
        self.decoder = SegFormerDecoder(
            in_channels_list=in_channels_list,
            unified_channels=config['decoder_channels'],
            num_classes=num_classes)

    def forward(self, x):
        # Collect the output of every encoder stage: stage 1 consumes the raw image,
        # each later stage consumes the previous stage's feature map
        stage_outputs = []
        for stage in self.stages:
            x = stage(x)
            stage_outputs.append(x)
        # Decode to a mask at 1/4 resolution
        seg_mask = self.decoder(stage_outputs)
        # Upsample to the original resolution
        seg_mask = F.interpolate(seg_mask, scale_factor=4, mode='bilinear', align_corners=False)
        return seg_mask


# Quick test
if __name__ == "__main__":
    model = SegFormer(num_classes=3, version='b0')
    print(f"Parameter count: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # Dummy input
    input_tensor = torch.randn(2, 3, 512, 512)  # [batch, channels, height, width]
    output = model(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output.shape}")  # expected: [2, 3, 512, 512]
    # Simple check of the output range
    print(f"Output min: {output.min().item():.4f}, max: {output.max().item():.4f}")

    # Optional: save a graph of the model structure
    try:
        from torchviz import make_dot
        dot = make_dot(output, params=dict(model.named_parameters()))
        dot.render("segformer_model", format="png")
        print("Model graph saved as segformer_model.png")
    except ImportError:
        print("torchviz is not installed; skipping the model graph")

Finally, take a look at DeepSeek's take on the trained model's metrics (lol).

