A Brief Walkthrough of the EfficientVMamba Code
Table of Contents
- EfficientVMamba Code
- EfficientVMamba
- Step 1: Starting from the Input
- Step 2: Overview of the Four-Stage Model
- Step 3: Instantiating self.layers()
- Step 4: What Exactly Is a VSSBlock
- SS2D
EfficientVMamba Code
EfficientVMamba
Step 1: Starting from the Input
```python
class EfficientVSSM(VSSM):
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], dims=[96, 192, 384, 768],
                 # =========================
                 d_state=16, dt_rank="auto", ssm_ratio=2.0, attn_drop_rate=0.,
                 shared_ssm=False, softmax_version=False,
                 # =========================
                 drop_rate=0., drop_path_rate=0.1, mlp_ratio=4.0,
                 patch_norm=True, norm_layer=nn.LayerNorm,
                 downsample_version: str = "v2",
                 use_checkpoint=False, step_size=2, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims

        self.patch_embed = nn.Sequential(
            nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
            Permute(0, 2, 3, 1),
            (norm_layer(self.embed_dim) if patch_norm else nn.Identity()),
        )

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            if downsample_version == "v2":
                downsample = self._make_downsample(
                    self.dims[i_layer],
                    self.dims[i_layer + 1],
                    norm_layer=norm_layer,
                ) if (i_layer < self.num_layers - 1) else nn.Identity()
            else:
                downsample = PatchMerging2D(
                    self.dims[i_layer],
                    self.dims[i_layer + 1],
                    norm_layer=norm_layer,
                ) if (i_layer < self.num_layers - 1) else nn.Identity()

            if i_layer < 2:
                self.layers.append(self._make_layer(
                    dim=self.dims[i_layer],
                    depth=depths[i_layer],
                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                    use_checkpoint=use_checkpoint,
                    norm_layer=norm_layer,
                    downsample=downsample,
                    d_state=d_state,
                    dt_rank=dt_rank,
                    ssm_ratio=ssm_ratio,
                    attn_drop_rate=attn_drop_rate,
                    shared_ssm=shared_ssm,
                    softmax_version=softmax_version,
                    mlp_ratio=mlp_ratio,
                    drop_rate=drop_rate,
                    step_size=step_size,
                ))
            else:
                self.layers.append(nn.Sequential(
                    Permute(0, 3, 1, 2),
                    *[InvertedResidual(self.dims[i_layer], self.dims[i_layer], expand_ratio=4, se_ratio=0.125,
                                       drop_connect_rate=dpr[sum(depths[:i_layer]) + i])
                      for i in range(depths[i_layer])],
                    Permute(0, 2, 3, 1),
                    downsample,
                ))

        self.classifier = nn.Sequential(OrderedDict(
            norm=norm_layer(self.num_features),  # B,H,W,C
            permute=Permute(0, 3, 1, 2),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor):
        x = self.patch_embed(x)
        for idx, layer in enumerate(self.layers):
            x = layer(x)
        x = self.classifier(x)
        return x
```
The model first applies the `patch_embed` operation to the input image `x`. The implementation is as follows:
`patch_embed(x)`
```python
self.patch_embed = nn.Sequential(
    nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
    Permute(0, 2, 3, 1),
    (norm_layer(self.embed_dim) if patch_norm else nn.Identity()),
)
```
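As a quick sanity check on the shapes, here is a minimal standalone sketch (assuming the default `patch_size=4`, `embed_dim=96`, and a 224×224 input; not the repository's code, just the same operations):

```python
import torch
import torch.nn as nn

# Stand-ins for the assumed defaults: patch_size=4, embed_dim=96.
conv = nn.Conv2d(3, 96, kernel_size=4, stride=4, bias=True)   # (B, 3, 224, 224) -> (B, 96, 56, 56)

x = torch.randn(2, 3, 224, 224)
feat = conv(x).permute(0, 2, 3, 1)    # channels-last, mimicking Permute(0, 2, 3, 1)
feat = nn.LayerNorm(96)(feat)         # the patch_norm branch
print(feat.shape)                      # torch.Size([2, 56, 56, 96])
```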
Next, the embedded tensor is passed through `self.layers`:
Step 2: Overview of the Four-Stage Model
`self.layers(x)`
```python
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
    if downsample_version == "v2":
        downsample = self._make_downsample(
            self.dims[i_layer],
            self.dims[i_layer + 1],
            norm_layer=norm_layer,
        ) if (i_layer < self.num_layers - 1) else nn.Identity()
    else:
        downsample = PatchMerging2D(
            self.dims[i_layer],
            self.dims[i_layer + 1],
            norm_layer=norm_layer,
        ) if (i_layer < self.num_layers - 1) else nn.Identity()

    if i_layer < 2:
        self.layers.append(self._make_layer(
            dim=self.dims[i_layer],
            depth=depths[i_layer],
            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
            use_checkpoint=use_checkpoint,
            norm_layer=norm_layer,
            downsample=downsample,
            d_state=d_state,
            dt_rank=dt_rank,
            ssm_ratio=ssm_ratio,
            attn_drop_rate=attn_drop_rate,
            shared_ssm=shared_ssm,
            softmax_version=softmax_version,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            step_size=step_size,
        ))
    else:
        self.layers.append(nn.Sequential(
            Permute(0, 3, 1, 2),
            *[InvertedResidual(self.dims[i_layer], self.dims[i_layer], expand_ratio=4, se_ratio=0.125,
                               drop_connect_rate=dpr[sum(depths[:i_layer]) + i])
              for i in range(depths[i_layer])],
            Permute(0, 2, 3, 1),
            downsample,
        ))
```
Analysis of the implementation:
- Composition of `self.layers`
`self.layers` is an `nn.ModuleList` whose contents are generated dynamically from the number of stages (`num_layers`) and the configuration. It contains two kinds of stages, depending on the stage index `i_layer`:
- First two stages (`i_layer < 2`): built by `_make_layer`, each consisting of several `VSSBlock` modules plus a downsampling layer (`downsample`).
```python
if i_layer < 2:
    self.layers.append(self._make_layer(
        dim=self.dims[i_layer],
        depth=depths[i_layer],
        drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
        use_checkpoint=use_checkpoint,
        norm_layer=norm_layer,
        downsample=downsample,
        d_state=d_state,
        dt_rank=dt_rank,
        ssm_ratio=ssm_ratio,
        attn_drop_rate=attn_drop_rate,
        shared_ssm=shared_ssm,
        softmax_version=softmax_version,
        mlp_ratio=mlp_ratio,
        drop_rate=drop_rate,
        step_size=step_size,
    ))
```
- Last two stages (`i_layer >= 2`): a sequence of several `InvertedResidual` modules followed by a downsampling layer.
```python
else:
    self.layers.append(nn.Sequential(
        Permute(0, 3, 1, 2),
        *[InvertedResidual(self.dims[i_layer], self.dims[i_layer], expand_ratio=4, se_ratio=0.125,
                           drop_connect_rate=dpr[sum(depths[:i_layer]) + i])
          for i in range(depths[i_layer])],
        Permute(0, 2, 3, 1),
        downsample,
    ))
```
The concrete structure of each stage is defined in `EfficientVSSM.__init__`; the core idea is to progressively downsample the features while extracting higher-level semantic information.
Loop execution logic
- The initial input `x` is the feature tensor produced by `patch_embed`, with shape (B, H', W', embed_dim).
- Inside the loop, `x` is passed through each stage of `self.layers` in turn:
  - For the first two stages (containing `VSSBlock`s), `layer(x)` first extracts features through several `VSSBlock`s (SSM operation, convolution branch, attention, etc.) and then downsamples via the `downsample` layer (shrinking the spatial size and increasing the channel count).
  - For the last two stages (containing `InvertedResidual`s), `layer(x)` first passes through several `InvertedResidual` modules (MobileNet-style inverted residuals with depthwise separable convolution and attention) and then downsamples via `downsample`.
- After each iteration, `x` is updated to the output of the current stage and becomes the input to the next one.
Tensor shape changes
Assume the input image has shape (B, H_0, W_0, C_0) after `patch_embed`, where C_0 = dims[0]. After each stage:
- Spatial dimensions (H, W): progressively halved by the stride-2 downsampling layers (H_0 → H_0/2 → H_0/4 → ...).
- Channel dimension C: grows with the stage index according to the `dims` list (C_0 → C_1 → C_2 → C_3, i.e. dims = [96, 192, 384, 768]).
After all stages, `x` has shape (B, H_final, W_final, C_final) with C_final = dims[-1] (e.g. 768), which is the input to the classifier. A small shape-bookkeeping sketch is given below.
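A minimal sketch of the expected shape progression, assuming a 224×224 input, `patch_size=4` and `dims=[96, 192, 384, 768]` (only the last stage keeps `nn.Identity()` as its downsample); this is bookkeeping only, not the model code:

```python
# Shape bookkeeping for the four stages (assumed defaults).
dims = [96, 192, 384, 768]
H = W = 224 // 4                          # after patch_embed: 56 x 56
print(f"patch_embed: (B, {H}, {W}, {dims[0]})")
for i in range(len(dims)):
    if i < len(dims) - 1:                 # every stage except the last halves H, W and switches to dims[i+1]
        H, W = H // 2, W // 2
        c_out = dims[i + 1]
    else:
        c_out = dims[i]
    print(f"stage {i}:    (B, {H}, {W}, {c_out})")
# Expected: (B,28,28,192) -> (B,14,14,384) -> (B,7,7,768) -> (B,7,7,768)
```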
Step 3: Instantiating self.layers()
`_make_layer()`
```python
def _make_layer(
    dim=96,
    drop_path=[0.1, 0.1],
    use_checkpoint=False,
    norm_layer=nn.LayerNorm,
    downsample=nn.Identity(),
    # ===========================
    ssm_d_state=16,
    ssm_ratio=2.0,
    ssm_rank_ratio=2.0,
    ssm_dt_rank="auto",
    ssm_act_layer=nn.SiLU,
    ssm_conv=3,
    ssm_conv_bias=True,
    ssm_drop_rate=0.0,
    ssm_simple_init=False,
    forward_type="v2",
    # ===========================
    mlp_ratio=4.0,
    mlp_act_layer=nn.GELU,
    mlp_drop_rate=0.0,
    step_size=2,
    **kwargs,
):
    depth = len(drop_path)
    blocks = []
    for d in range(depth):
        blocks.append(VSSBlock(
            hidden_dim=dim,
            drop_path=drop_path[d],
            norm_layer=norm_layer,
            ssm_d_state=ssm_d_state,
            ssm_ratio=ssm_ratio,
            ssm_rank_ratio=ssm_rank_ratio,
            ssm_dt_rank=ssm_dt_rank,
            ssm_act_layer=ssm_act_layer,
            ssm_conv=ssm_conv,
            ssm_conv_bias=ssm_conv_bias,
            ssm_drop_rate=ssm_drop_rate,
            ssm_simple_init=ssm_simple_init,
            forward_type=forward_type,
            mlp_ratio=mlp_ratio,
            mlp_act_layer=mlp_act_layer,
            mlp_drop_rate=mlp_drop_rate,
            use_checkpoint=use_checkpoint,
            step_size=step_size,
        ))

    return nn.Sequential(OrderedDict(
        blocks=nn.Sequential(*blocks),
        downsample=downsample,
    ))
```
Calling `_make_layer` returns an ordered sequence (`nn.Sequential(OrderedDict(...))`) with two parts:
- `blocks`: a sequence of `depth` `VSSBlock`s (`nn.Sequential(*blocks)`), responsible for feature extraction.
- `downsample`: the downsampling layer, applied after all `VSSBlock`s have run.
For example, with `depth=2` and a convolutional downsampling layer, the instantiated structure looks like:
```python
nn.Sequential(OrderedDict([
    ('blocks', nn.Sequential(
        VSSBlock(hidden_dim=96, drop_path=0.1, ...),  # first block
        VSSBlock(hidden_dim=96, drop_path=0.1, ...),  # second block
    )),
    ('downsample', nn.Sequential(                     # downsampling layer (example)
        Permute(0, 3, 1, 2),
        nn.Conv2d(96, 192, kernel_size=2, stride=2),
        Permute(0, 2, 3, 1),
        nn.LayerNorm(192),
    )),
]))
```
Step 4: What Exactly Is a VSSBlock
VSSBlock is a hybrid module: an SSM models global dependencies, a convolution branch extracts local features, and channel attention reweights the important channels; residual connections and an MLP further improve its expressive power.
```python
class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        # =============================
        ssm_d_state: int = 16,
        ssm_ratio=2.0,
        ssm_rank_ratio=2.0,
        ssm_dt_rank: Any = "auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv: int = 3,
        ssm_conv_bias=True,
        ssm_drop_rate: float = 0,
        ssm_simple_init=False,
        forward_type="v2",
        # =============================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate: float = 0.0,
        # =============================
        use_checkpoint: bool = False,
        step_size=2,
        **kwargs,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm = norm_layer(hidden_dim)
        self.op = SS2D(
            d_model=hidden_dim,
            d_state=ssm_d_state,
            ssm_ratio=ssm_ratio,
            ssm_rank_ratio=ssm_rank_ratio,
            dt_rank=ssm_dt_rank,
            act_layer=ssm_act_layer,
            # ==========================
            d_conv=ssm_conv,
            conv_bias=ssm_conv_bias,
            # ==========================
            dropout=ssm_drop_rate,
            # bias=False,
            # ==========================
            # dt_min=0.001,
            # dt_max=0.1,
            # dt_init="random",
            # dt_scale="random",
            # dt_init_floor=1e-4,
            simple_init=ssm_simple_init,
            # ==========================
            forward_type=forward_type,
            step_size=step_size,
        )
        self.conv_branch = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1, groups=hidden_dim),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim, 1),
        )
        self.se = BiAttn(hidden_dim)
        self.drop_path = DropPath(drop_path)

        self.mlp_branch = mlp_ratio > 0
        if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim,
                           act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False)

    def _forward(self, input: torch.Tensor):
        x = self.norm(input)
        x_ssm = self.op(x)
        x_conv = self.conv_branch(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        x = self.se(x_ssm) + self.se(x_conv)
        x = input + self.drop_path(x)
        if self.mlp_branch:
            x = x + self.drop_path(self.mlp(self.norm2(x)))  # FFN
        return x

    def forward(self, input: torch.Tensor):
        if self.use_checkpoint:
            return checkpoint.checkpoint(self._forward, input)
        else:
            return self._forward(input)
```
After receiving the input tensor `x`, the block first normalizes it with `self.norm()` (LayerNorm by default, which stabilizes training) and then splits into two branches:
- The first branch performs the SSM operation: `x_ssm = self.op(x)`.
- The second branch is the convolution branch `self.conv_branch`.
- Both branch outputs are then reweighted by `self.se`.
- Finally, a residual connection is added and the result is output.
`x_ssm = self.op(x)`
```python
self.op = SS2D(
    d_model=hidden_dim,
    d_state=ssm_d_state,
    ssm_ratio=ssm_ratio,
    ssm_rank_ratio=ssm_rank_ratio,
    dt_rank=ssm_dt_rank,
    act_layer=ssm_act_layer,
    # ==========================
    d_conv=ssm_conv,
    conv_bias=ssm_conv_bias,
    # ==========================
    dropout=ssm_drop_rate,
    # bias=False,
    # ==========================
    # dt_min=0.001,
    # dt_max=0.1,
    # dt_init="random",
    # dt_scale="random",
    # dt_init_floor=1e-4,
    simple_init=ssm_simple_init,
    # ==========================
    forward_type=forward_type,
    step_size=step_size,
)
```
`self.conv_branch`
```python
self.conv_branch = nn.Sequential(
    nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1, groups=hidden_dim),
    nn.BatchNorm2d(hidden_dim),
    nn.GELU(),
    nn.Conv2d(hidden_dim, hidden_dim, 1),
)
```
`self.se()`
```python
self.se = BiAttn(hidden_dim)

class BiAttn(nn.Module):
    def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid):
        super().__init__()
        reduce_channels = int(in_channels * act_ratio)
        self.norm = nn.LayerNorm(in_channels)
        self.global_reduce = nn.Linear(in_channels, reduce_channels)
        # self.local_reduce = nn.Linear(in_channels, reduce_channels)
        self.act_fn = act_fn()
        self.channel_select = nn.Linear(reduce_channels, in_channels)
        # self.spatial_select = nn.Linear(reduce_channels * 2, 1)
        self.gate_fn = gate_fn()

    def forward(self, x):
        ori_x = x
        x = self.norm(x)
        x_global = x.mean([1, 2], keepdim=True)
        x_global = self.act_fn(self.global_reduce(x_global))
        # x_local = self.act_fn(self.local_reduce(x))

        c_attn = self.channel_select(x_global)
        c_attn = self.gate_fn(c_attn)
        attn = c_attn
        out = ori_x * attn
        return out
```
- `BiAttn` is essentially a channel-attention module: it learns a global descriptor and uses it to rescale channel importance, much like the Squeeze-and-Excitation mechanism in SENet. A quick usage sketch follows.
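Assuming the `BiAttn` class shown above is in scope (plus `torch` and `torch.nn as nn`), the module keeps the (B, H, W, C) shape and only rescales channels:

```python
import torch

attn = BiAttn(in_channels=96)      # reduce_channels = int(96 * 0.125) = 12
x = torch.randn(2, 56, 56, 96)     # channels-last feature map, as used inside VSSBlock
out = attn(x)
print(out.shape)                    # torch.Size([2, 56, 56, 96]) -- same shape, per-channel rescaling only
```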
SS2D
```python
class SS2D(nn.Module):
    def __init__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        ssm_rank_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.SiLU,
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        simple_init=False,
        # ======================
        forward_type="v2",
        # ======================
        step_size=2,
        **kwargs,
    ):
        """ssm_rank_ratio would be used in the future..."""
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_expand = int(ssm_ratio * d_model)
        d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand
        self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state  # 20240109
        self.d_conv = d_conv
        self.step_size = step_size

        # disable z act ======================================
        self.disable_z_act = forward_type[-len("nozact"):] == "nozact"
        if self.disable_z_act:
            forward_type = forward_type[:-len("nozact")]

        # softmax | sigmoid | norm ===========================
        if forward_type[-len("softmax"):] == "softmax":
            forward_type = forward_type[:-len("softmax")]
            self.out_norm = nn.Softmax(dim=1)
        elif forward_type[-len("sigmoid"):] == "sigmoid":
            forward_type = forward_type[:-len("sigmoid")]
            self.out_norm = nn.Sigmoid()
        else:
            self.out_norm = nn.LayerNorm(d_inner)

        # forward_type =======================================
        self.forward_core = dict(
            v0=self.forward_corev0,
            v0_seq=self.forward_corev0_seq,
            v1=self.forward_corev2,
            v2=self.forward_corev2,
            share_ssm=self.forward_corev0_share_ssm,
            share_a=self.forward_corev0_share_a,
        ).get(forward_type, self.forward_corev2)
        self.K = 4 if forward_type not in ["share_ssm"] else 1
        self.K2 = self.K if forward_type not in ["share_a"] else 1

        # in proj =======================================
        self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs)
        self.act: nn.Module = act_layer()

        # conv =======================================
        if self.d_conv > 1:
            self.conv2d = nn.Conv2d(
                in_channels=d_expand,
                out_channels=d_expand,
                groups=d_expand,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )

        # rank ratio =====================================
        self.ssm_low_rank = False
        if d_inner < d_expand:
            self.ssm_low_rank = True
            self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)
            self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)

        # x proj ============================
        self.x_proj = [
            nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
            for _ in range(self.K)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj

        # dt proj ============================
        self.dt_projs = [
            self.dt_init(self.dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
            for _ in range(self.K)
        ]
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K, inner)
        del self.dt_projs

        # A, D =======================================
        self.A_logs = self.A_log_init(self.d_state, d_inner, copies=self.K2, merge=True)  # (K * D, N)
        self.Ds = self.D_init(d_inner, copies=self.K2, merge=True)  # (K * D)

        # out proj =======================================
        self.out_proj = nn.Linear(d_expand, d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        if simple_init:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((self.K2 * d_inner)))
            self.A_logs = nn.Parameter(torch.randn((self.K2 * d_inner, self.d_state)))  # A == -A_logs.exp() < 0; 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank)))
            self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner)))

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 0:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    # only used to run previous version
    def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
            return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)

        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([
            x.view(B, -1, L),
            torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L),
        ], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

        xs = xs.float().view(B, -1, L)                 # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float()                                # (b, k, d_state, l)
        Cs = Cs.float()                                # (b, k, d_state, l)

        As = -torch.exp(self.A_logs.float())           # (k * d, d_state)
        Ds = self.Ds.float()                           # (k * d)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
        # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1

        out_y = selective_scan(
            xs, dts, As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        # assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
        y = self.out_norm(y).view(B, H, W, -1)
        return (y.to(x.dtype) if to_dtype else y)

    def forward_corev0_seq(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
            return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)

        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([
            x.view(B, -1, L),
            torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L),
        ], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

        xs = xs.float()                 # (b, k, d, l)
        dts = dts.contiguous().float()  # (b, k, d, l)
        Bs = Bs.float()                 # (b, k, d_state, l)
        Cs = Cs.float()                 # (b, k, d_state, l)

        As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state)  # (k, d, d_state)
        Ds = self.Ds.float().view(K, -1)                                # (k, d)
        dt_projs_bias = self.dt_projs_bias.float().view(K, -1)          # (k, d)

        out_y = []
        for i in range(4):
            yi = selective_scan(
                xs[:, i], dts[:, i],
                As[i], Bs[:, i], Cs[:, i], Ds[i],
                delta_bias=dt_projs_bias[i],
                delta_softplus=True,
            ).view(B, -1, L)
            out_y.append(yi)
        out_y = torch.stack(out_y, dim=1)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
        y = self.out_norm(y).view(B, H, W, -1)
        return (y.to(x.dtype) if to_dtype else y)

    def forward_corev0_share_ssm(self, x: torch.Tensor, channel_first=False):
        """we may conduct this ablation later, but not with v0."""
        ...

    def forward_corev0_share_a(self, x: torch.Tensor, channel_first=False):
        """we may conduct this ablation later, but not with v0."""
        ...

    def forward_corev2(self, x: torch.Tensor, nrows=-1, channel_first=False, step_size=2):
        nrows = 1
        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        if self.ssm_low_rank:
            x = self.in_rank(x)
        x = cross_selective_scan(
            x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
            self.A_logs, self.Ds, getattr(self, "out_norm", None),
            nrows=nrows, delta_softplus=True, step_size=step_size,
        )
        if self.ssm_low_rank:
            x = self.out_rank(x)
        return x

    def forward(self, x: torch.Tensor, **kwargs):
        xz = self.in_proj(x)
        if self.d_conv > 1:
            x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
            if not self.disable_z_act:
                z = self.act(z)
            x = x.permute(0, 3, 1, 2).contiguous()
            x = self.act(self.conv2d(x))  # (b, d, h, w)
        else:
            if self.disable_z_act:
                x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
                x = self.act(x)
            else:
                xz = self.act(xz)
                x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
        y = self.forward_core(x, channel_first=(self.d_conv > 1), step_size=self.step_size)
        y = y * z
        out = self.dropout(self.out_proj(y))
        return out
```
The full forward pass
```python
def forward(self, x: torch.Tensor, **kwargs):
    xz = self.in_proj(x)
    if self.d_conv > 1:
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
        if not self.disable_z_act:
            z = self.act(z)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)
    else:
        if self.disable_z_act:
            x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
            x = self.act(x)
        else:
            xz = self.act(xz)
            x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
    y = self.forward_core(x, channel_first=(self.d_conv > 1), step_size=self.step_size)
    y = y * z
    out = self.dropout(self.out_proj(y))
    return out
```
`self.in_proj(x)`
self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs)
- `in_proj`: expands the input channels to `d_expand * 2`; one half later becomes `x` (the main features) and the other half becomes `z` (the gate).
The gating idea follows Mamba from NLP; a minimal code sketch of the split is shown below.
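A minimal standalone sketch of the gated split (illustrative dimensions only, not the repository's code): the projection doubles the channels, one half is processed, the other half gates it.

```python
import torch
import torch.nn as nn

d_model, d_expand = 96, 192                 # d_expand = int(ssm_ratio * d_model) with ssm_ratio=2.0
in_proj = nn.Linear(d_model, d_expand * 2)

feat = torch.randn(2, 56, 56, d_model)      # (B, H, W, C), channels-last
xz = in_proj(feat)                          # (2, 56, 56, 384)
x, z = xz.chunk(2, dim=-1)                  # main features and gate, each (2, 56, 56, 192)

y = x * torch.nn.functional.silu(z)         # gate the main path, mirroring `y = y * z` after `z = self.act(z)`
print(y.shape)                              # torch.Size([2, 56, 56, 192])
```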
The branch condition
```python
if self.d_conv > 1:
    x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
    if not self.disable_z_act:
        z = self.act(z)
    x = x.permute(0, 3, 1, 2).contiguous()
    x = self.act(self.conv2d(x))  # (b, d, h, w)
else:
    if self.disable_z_act:
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
        x = self.act(x)
    else:
        xz = self.act(xz)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
```
First branch
The `d_conv > 1` check is essentially a switch for "whether to add local convolutional enhancement".
If `d_conv > 1` (convolution enabled):
```python
if not self.disable_z_act:
    z = self.act(z)
```
If `z` is to be activated, a SiLU activation is applied:
```python
self.act: nn.Module = act_layer()
# Since act_layer=nn.SiLU in __init__, the activation here is SiLU.
```
- Next, the dimension order is adjusted and `x` is made contiguous:
x = x.permute(0, 3, 1, 2).contiguous()
- The tensor is then passed through the depthwise 2D convolution, followed by a SiLU activation:
x = self.act(self.conv2d(x)) # (b, d, h, w)
```python
# --------------------------------------------------------------
if self.d_conv > 1:
    self.conv2d = nn.Conv2d(
        in_channels=d_expand,   # d_expand = int(ssm_ratio * d_model); d_model=96; ssm_ratio=2.0
        out_channels=d_expand,
        groups=d_expand,
        bias=conv_bias,         # conv_bias=True
        kernel_size=d_conv,     # d_conv=3, < 2 means no conv
        padding=(d_conv - 1) // 2,
        **factory_kwargs,
    )
```
Second branch
```python
else:
    if self.disable_z_act:
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
        x = self.act(x)
    else:
        xz = self.act(xz)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)
```
- If `z` is not to be activated, `xz` is split along the last dimension right away; `x` skips the convolution and is activated directly.
- If `z` is to be activated, the whole `xz` tensor is activated first and then split into separate `x` and `z`.
Producing the output
```python
y = self.forward_core(x, channel_first=(self.d_conv > 1), step_size=self.step_size)
y = y * z
out = self.dropout(self.out_proj(y))
return out
```
`self.forward_core()`
```python
self.forward_core = dict(
    v0=self.forward_corev0,
    v0_seq=self.forward_corev0_seq,
    v1=self.forward_corev2,
    v2=self.forward_corev2,
    share_ssm=self.forward_corev0_share_ssm,
    share_a=self.forward_corev0_share_a,
).get(forward_type, self.forward_corev2)
```
This does the following:
- If forward_type="v0" → self.forward_core = self.forward_corev0
- If forward_type="v0_seq" → self.forward_core = self.forward_corev0_seq
- If forward_type="v1" or "v2" → self.forward_core = self.forward_corev2
- If forward_type="share_ssm" → self.forward_core = self.forward_corev0_share_ssm
- If forward_type="share_a" → self.forward_core = self.forward_corev0_share_a
- If nothing matches → the default, self.forward_core = self.forward_corev2
So `self.forward_core` is effectively a "function pointer" that selects which forward algorithm is used. A minimal sketch of this dispatch pattern is given below.
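As a standalone illustration (a toy class, not EfficientVMamba code), `dict(...).get(key, default)` simply picks one bound method at construction time:

```python
class Toy:
    def core_v0(self, x):
        return x + 0

    def core_v2(self, x):
        return x + 2

    def __init__(self, forward_type="v2"):
        # Bound methods are stored like any other attribute; unknown keys fall back to core_v2.
        self.forward_core = dict(v0=self.core_v0, v1=self.core_v2, v2=self.core_v2).get(
            forward_type, self.core_v2
        )

print(Toy("v0").forward_core(1))       # 1
print(Toy("unknown").forward_core(1))  # 3
```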
Walkthrough of `forward_corev0`
```python
def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)

    if not channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()
    B, C, H, W = x.shape
    L = H * W
    K = 4

    x_hwwh = torch.stack([
        x.view(B, -1, L),
        torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L),
    ], dim=1).view(B, 2, -1, L)
    xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
    # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

    xs = xs.float().view(B, -1, L)                 # (b, k * d, l)
    dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
    Bs = Bs.float()                                # (b, k, d_state, l)
    Cs = Cs.float()                                # (b, k, d_state, l)

    As = -torch.exp(self.A_logs.float())           # (k * d, d_state)
    Ds = self.Ds.float()                           # (k * d)
    dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

    # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
    # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1

    out_y = selective_scan(
        xs, dts, As, Bs, Cs, Ds,
        delta_bias=dt_projs_bias,
        delta_softplus=True,
    ).view(B, K, -1, L)
    # assert out_y.dtype == torch.float

    inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
    wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

    y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
    y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
    y = self.out_norm(y).view(B, H, W, -1)
    return (y.to(x.dtype) if to_dtype else y)
```
- Adjust the channel order
```python
if not channel_first:
    x = x.permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
L = H * W
K = 4
```
First, if the input is in (B, H, W, C) format, it is converted to (B, C, H, W).
`L`: the sequence length, H*W, used for the subsequent flattening.
`K = 4`: the following operations will generate 4 "scan paths".
- Build the multi-direction sequences
```python
x_hwwh = torch.stack([
    x.view(B, -1, L),
    torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L),
], dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (B, K, D, L)
```
- Forward scan in the horizontal direction
x.view(B, -1, L)
Flattens the spatial grid: (H, W) → (L,).
- Forward scan in the vertical direction
torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)
Here, dimensions 2 and 3 of `x` are swapped first, so that flattening yields a vertical (column-major) scan. The two forward scans are then stacked along dim=1:
`x_hwwh = torch.stack([horizontal_forward, transposed_and_flattened_vertical_forward], dim=1)`, which adds a direction dimension and is then viewed as (B, 2, -1, L).
- Build the reversed sequences
torch.flip(x_hwwh, dims=[-1])
`torch.flip(..., dims=[-1])` reverses along L, producing the backward versions of both the horizontal and the vertical sequence.
- Concatenate the forward and reversed sequences (a toy example follows the snippet below)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (B, K, D, L)
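To make the four scan orders concrete, here is a tiny standalone example (B=1, C=1, H=W=2; not part of the model code):

```python
import torch

x = torch.arange(1, 5).view(1, 1, 2, 2)        # [[1, 2], [3, 4]]
B, C, H, W = x.shape
L = H * W

hw = x.view(B, -1, L)                                          # row-major:    [1, 2, 3, 4]
wh = torch.transpose(x, 2, 3).contiguous().view(B, -1, L)      # column-major: [1, 3, 2, 4]
x_hwwh = torch.stack([hw, wh], dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (B, 4, C, L)

print(xs.squeeze())
# tensor([[1, 2, 3, 4],   # horizontal, forward
#         [1, 3, 2, 4],   # vertical, forward
#         [4, 3, 2, 1],   # horizontal, backward
#         [4, 2, 3, 1]])  # vertical, backward
```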
- Linear projection
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
where `self.x_proj_weight` is:
```python
self.x_proj = [
    nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
    for _ in range(self.K)
]
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))
# (K, N, inner); the projection matrix self.x_proj_weight has shape (K, C_out, D),
# where C_out = dt_rank + d_state + d_state
del self.x_proj
```
So this projection multiplies the four directional sequences (horizontal forward/backward, vertical forward/backward) by their projection weights, i.e. one matrix multiplication per scan path k (a small equivalence check follows):
`x_dbl[b, k, :, l] = x_proj_weight[k] @ xs[b, k, :, l]`
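A small self-contained check of this equivalence (random shapes chosen purely for illustration):

```python
import torch

B, K, D, L, C_out = 2, 4, 8, 16, 6
xs = torch.randn(B, K, D, L)
W = torch.randn(K, C_out, D)                 # plays the role of x_proj_weight

einsum_out = torch.einsum("b k d l, k c d -> b k c l", xs, W)

# Same thing as one matmul per scan path k: (C_out, D) @ (B, D, L) -> (B, C_out, L)
loop_out = torch.stack([W[k] @ xs[:, k] for k in range(K)], dim=1)

print(torch.allclose(einsum_out, loop_out, atol=1e-6))   # True
```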
- Split into delta / B / C
```python
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
'''
dt_rank="auto",
self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
d_state=16,
self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state
'''
```
`x_dbl` is split along dim=2. Since dt_rank + d_state + d_state = C_out, the split is exact. The split sizes come from `self.x_proj = [nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) for _ in range(self.K)]`, i.e. the output width `(self.dt_rank + self.d_state * 2)`.
- Projecting `dts`
dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
- Purpose: map the r (rank) dimension to the d dimension using each path's own projection matrix.
- Output: `dts` of shape (B, K, d, L), the per-path delta time coefficients used for the state update.
- Reshape
xs = xs.float().view(B, -1, L) # (B, K*d, L)
dts = dts.contiguous().float().view(B, -1, L) # (B, K*d, L)
The path dimension K and the channel dimension d are merged so the tensors can be fed to `SelectiveScan`, whose interface expects (B, D, L) with D = K * d.
Bs = Bs.float() # (B, K, d_state, L)
Cs = Cs.float() # (B, K, d_state, L)
- `Bs`: each path's input projection, used to update the recurrent state.
- `Cs`: each path's output projection, used to produce the final output.
As = -torch.exp(self.A_logs.float()) # (K*d, d_state)
- `A` is the state-transition matrix of each scan path.
- Using `-exp()` keeps A negative, so the recurrence stays stable (a common stabilization trick).
Ds = self.Ds.float() # (K*d)
- `Ds` are the weights of the direct skip/residual connection.
- Inside `SelectiveScan`, the output is y_t = C_t · h_t + D * u_t.
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (K*d)
- The bias for delta.
- When computing delta, it is added to `dts`: delta = delta + dt_projs_bias (followed by softplus, since delta_softplus=True).
- How this produces the output
`selective_scan` (i.e. `SelectiveScan.apply`) is called to run the actual recurrence:
```python
out_y = selective_scan(
    xs, dts, As, Bs, Cs, Ds,
    delta_bias=dt_projs_bias,
    delta_softplus=True,
).view(B, K, -1, L)
```
where `selective_scan` refers to:
```python
def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
    return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
```
The concrete implementation lives in mamba/mamba_ssm/ops/selective_scan_interface.py. To make the recurrence explicit, a slow reference version is sketched below.
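This is not the CUDA kernel used by the repository, only a minimal, slow PyTorch sketch of the recurrence that the scan computes (it omits the grouping that lets the K paths share B/C across their d channels; shapes: u, delta (B, D, L); A (D, N); B, C (B, N, L); D (D,)):

```python
import torch
import torch.nn.functional as F

def selective_scan_ref(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
    """Slow reference recurrence for illustration only (no K-group sharing of B/C)."""
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        delta = F.softplus(delta)
    Bsz, Dch, L = u.shape
    N = A.shape[1]
    h = u.new_zeros(Bsz, Dch, N)                       # hidden state per channel
    ys = []
    for t in range(L):
        dA = torch.exp(delta[:, :, t, None] * A)       # discretized state transition, (Bsz, Dch, N)
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        h = dA * h + dBu                               # h_t = exp(dt*A) h_{t-1} + dt*B_t*u_t
        ys.append((h * C[:, None, :, t]).sum(-1))      # y_t = C_t . h_t
    y = torch.stack(ys, dim=-1)                        # (Bsz, Dch, L)
    if D is not None:
        y = y + D[None, :, None] * u                   # skip connection D * u
    return y
```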
- Post-processing
```python
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
y = self.out_norm(y).view(B, H, W, -1)
```
- The first four lines above fuse the four scan paths: the two reversed outputs are flipped back, the two vertical outputs are transposed back to row-major order, and the four aligned sequences are summed.
- Next, the dimension order is adjusted:
y.transpose(dim0=1, dim1=2).contiguous()
- Finally, the tensor is normalized (LayerNorm here, via `self.out_norm`) and reshaped back to (B, H, W, C).
Walkthrough of `forward_corev0_seq`
```python
def forward_corev0_seq(self, x: torch.Tensor, to_dtype=False, channel_first=False):
    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)

    if not channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()
    B, C, H, W = x.shape
    L = H * W
    K = 4

    x_hwwh = torch.stack([
        x.view(B, -1, L),
        torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L),
    ], dim=1).view(B, 2, -1, L)
    xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
    # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

    xs = xs.float()                 # (b, k, d, l)
    dts = dts.contiguous().float()  # (b, k, d, l)
    Bs = Bs.float()                 # (b, k, d_state, l)
    Cs = Cs.float()                 # (b, k, d_state, l)

    As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state)  # (k, d, d_state)
    Ds = self.Ds.float().view(K, -1)                                # (k, d)
    dt_projs_bias = self.dt_projs_bias.float().view(K, -1)          # (k, d)

    out_y = []
    for i in range(4):
        yi = selective_scan(
            xs[:, i], dts[:, i],
            As[i], Bs[:, i], Cs[:, i], Ds[i],
            delta_bias=dt_projs_bias[i],
            delta_softplus=True,
        ).view(B, -1, L)
        out_y.append(yi)
    out_y = torch.stack(out_y, dim=1)
    assert out_y.dtype == torch.float

    inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
    wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
    invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

    y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
    y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
    y = self.out_norm(y).view(B, H, W, -1)
    return (y.to(x.dtype) if to_dtype else y)
```
The main differences from the previous version:
- `forward_corev0` flattens all paths and makes a single vectorized call to `selective_scan`, with data shaped (B, K*D, L): high parallelism but less flexibility.
- `forward_corev0_seq` keeps the original (B, K, D, L) shape and loops over the paths one at a time: more flexible and easier to debug, but less parallel.
- `forward_corev0`:
```python
out_y = selective_scan(
    xs, dts, As, Bs, Cs, Ds,
    delta_bias=dt_projs_bias,
    delta_softplus=True,
).view(B, K, -1, L)
```
- `forward_corev0_seq`:
```python
out_y = []
for i in range(4):
    yi = selective_scan(
        xs[:, i], dts[:, i],
        As[i], Bs[:, i], Cs[:, i], Ds[i],
        delta_bias=dt_projs_bias[i],
        delta_softplus=True,
    ).view(B, -1, L)
    out_y.append(yi)
out_y = torch.stack(out_y, dim=1)
```
Walkthrough of `forward_corev2`
```python
def forward_corev2(self, x: torch.Tensor, nrows=-1, channel_first=False, step_size=2):
    nrows = 1
    if not channel_first:
        x = x.permute(0, 3, 1, 2).contiguous()
    if self.ssm_low_rank:
        x = self.in_rank(x)
    x = cross_selective_scan(
        x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
        self.A_logs, self.Ds, getattr(self, "out_norm", None),
        nrows=nrows, delta_softplus=True, step_size=step_size,
    )
    if self.ssm_low_rank:
        x = self.out_rank(x)
    return x
```
- Channel-order adjustment
```python
if not channel_first:
    x = x.permute(0, 3, 1, 2).contiguous()
```
- Low-rank mapping
```python
if self.ssm_low_rank:
    x = self.in_rank(x)
```
where `self.ssm_low_rank` and `self.in_rank` are defined as:
```python
self.ssm_low_rank = False
if d_inner < d_expand:
    self.ssm_low_rank = True
    self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)
    self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)

'''
d_expand = int(ssm_ratio * d_model)                                                          # 192
d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand  # 192
'''
```
- `cross_selective_scan`
```python
x = cross_selective_scan(
    x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
    self.A_logs, self.Ds, getattr(self, "out_norm", None),
    nrows=nrows, delta_softplus=True, step_size=step_size,
)
```
Its implementation is as follows:
```python
def cross_selective_scan(
    x: torch.Tensor = None,
    x_proj_weight: torch.Tensor = None,
    x_proj_bias: torch.Tensor = None,
    dt_projs_weight: torch.Tensor = None,
    dt_projs_bias: torch.Tensor = None,
    A_logs: torch.Tensor = None,
    Ds: torch.Tensor = None,
    out_norm: torch.nn.Module = None,
    nrows=-1,
    delta_softplus=True,
    to_dtype=True,
    step_size=2,
):
    B, D, H, W = x.shape
    D, N = A_logs.shape
    K, D, R = dt_projs_weight.shape
    L = H * W

    if nrows < 1:
        if D % 4 == 0:
            nrows = 4
        elif D % 3 == 0:
            nrows = 3
        elif D % 2 == 0:
            nrows = 2
        else:
            nrows = 1

    # H * W
    ori_h, ori_w = H, W

    xs = EfficientScan.apply(x, step_size)  # [B, C, H*W] -> [B, 4, C, H//w * W//w]

    # H//w * W//w
    H = math.ceil(H / step_size)
    W = math.ceil(W / step_size)
    L = H * W

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
    if x_proj_bias is not None:
        x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)

    xs = xs.view(B, -1, L).to(torch.float)
    dts = dts.contiguous().view(B, -1, L).to(torch.float)
    As = -torch.exp(A_logs.to(torch.float))
    Bs = Bs.contiguous().to(torch.float)
    Cs = Cs.contiguous().to(torch.float)
    Ds = Ds.to(torch.float)  # (K * c)
    delta_bias = dt_projs_bias.view(-1).to(torch.float)

    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)

    ys: torch.Tensor = selective_scan(
        xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows,
    ).view(B, K, -1, L)

    ori_h, ori_w = int(ori_h), int(ori_w)
    y = EfficientMerge.apply(ys, ori_h, ori_w, step_size)  # [B, 4, C, H//w * W//w] -> [B, C, H*W]

    H = ori_h
    W = ori_w
    L = H * W

    y = y.transpose(dim0=1, dim1=2).contiguous()
    y = out_norm(y).view(B, H, W, -1)

    return (y.to(x.dtype) if to_dtype else y)
```
3.1 The scan strategy proposed in this paper, `EfficientScan.apply(x, step_size)`:
```python
class EfficientScan(torch.autograd.Function):
    # [B, C, H, W] -> [B, 4, C, H * W] (original)
    # [B, C, H, W] -> [B, 4, C, H/w * W/w]

    @staticmethod
    def forward(ctx, x: torch.Tensor, step_size=2):  # [B, C, H, W] -> [B, 4, H/w * W/w]
        B, C, org_h, org_w = x.shape
        ctx.shape = (B, C, org_h, org_w)
        ctx.step_size = step_size

        if org_w % step_size != 0:
            pad_w = step_size - org_w % step_size
            x = F.pad(x, (0, pad_w, 0, 0))
        W = x.shape[3]

        if org_h % step_size != 0:
            pad_h = step_size - org_h % step_size
            x = F.pad(x, (0, 0, 0, pad_h))
        H = x.shape[2]

        H = H // step_size
        W = W // step_size

        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x[:, :, ::step_size, ::step_size].contiguous().view(B, C, -1)
        xs[:, 1] = x.transpose(dim0=2, dim1=3)[:, :, ::step_size, 1::step_size].contiguous().view(B, C, -1)
        xs[:, 2] = x[:, :, ::step_size, 1::step_size].contiguous().view(B, C, -1)
        xs[:, 3] = x.transpose(dim0=2, dim1=3)[:, :, 1::step_size, 1::step_size].contiguous().view(B, C, -1)

        xs = xs.view(B, 4, C, -1)
        return xs

    @staticmethod
    def backward(ctx, grad_xs: torch.Tensor):  # [B, 4, H/w * W/w] -> [B, C, H, W]
        B, C, org_h, org_w = ctx.shape
        step_size = ctx.step_size

        newH, newW = math.ceil(org_h / step_size), math.ceil(org_w / step_size)
        grad_x = grad_xs.new_empty((B, C, newH * step_size, newW * step_size))
        grad_xs = grad_xs.view(B, 4, C, newH, newW)

        grad_x[:, :, ::step_size, ::step_size] = grad_xs[:, 0].reshape(B, C, newH, newW)
        grad_x[:, :, 1::step_size, ::step_size] = grad_xs[:, 1].reshape(B, C, newW, newH).transpose(dim0=2, dim1=3)
        grad_x[:, :, ::step_size, 1::step_size] = grad_xs[:, 2].reshape(B, C, newH, newW)
        grad_x[:, :, 1::step_size, 1::step_size] = grad_xs[:, 3].reshape(B, C, newW, newH).transpose(dim0=2, dim1=3)

        if org_h != grad_x.shape[-2] or org_w != grad_x.shape[-1]:
            grad_x = grad_x[:, :, :org_h, :org_w]

        return grad_x, None
```
`EfficientScan` is a custom op derived from `torch.autograd.Function`. Its `forward` method tiles the input image tensor of shape [B, C, H, W] with the given stride `step_size` (zero-padding first if a side is not divisible by the stride), extracts the sub-sequences at 4 different offsets ((0,0), (1,0), (0,1), (1,1)), and outputs the blocks with shape [B, 4, C, H/w * W/w]. The `backward` method reverses this tiling, restoring the block-wise gradient [B, 4, C, H/w * W/w] to a gradient of the original image size [B, C, H, W], so autograd stays correct. Together they implement efficient tiling of image tensors with proper gradient propagation. A tiny numerical example follows.
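A standalone toy example of the same strided sampling on a 4×4 grid (step_size=2; written with plain indexing rather than the autograd Function, just to show which pixels each of the four branches picks up):

```python
import torch

x = torch.arange(16).view(1, 1, 4, 4)   # rows: [0..3], [4..7], [8..11], [12..15]
s = 2

branch0 = x[:, :, ::s, ::s].reshape(-1)                       # offset (0, 0)
branch1 = x.transpose(2, 3)[:, :, ::s, 1::s].reshape(-1)      # offset (1, 0), read column-major
branch2 = x[:, :, ::s, 1::s].reshape(-1)                      # offset (0, 1)
branch3 = x.transpose(2, 3)[:, :, 1::s, 1::s].reshape(-1)     # offset (1, 1), read column-major

print(branch0.tolist())  # [0, 2, 8, 10]
print(branch1.tolist())  # [4, 12, 6, 14]
print(branch2.tolist())  # [1, 3, 9, 11]
print(branch3.tolist())  # [5, 13, 7, 15]
```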
3.2 EfficientMerge.apply(ys, ori_h, ori_w, step_size)
Implementation:
```python
class EfficientMerge(torch.autograd.Function):  # [B, 4, C, H/w * W/w] -> [B, C, H*W]
    @staticmethod
    def forward(ctx, ys: torch.Tensor, ori_h: int, ori_w: int, step_size=2):
        B, K, C, L = ys.shape
        H, W = math.ceil(ori_h / step_size), math.ceil(ori_w / step_size)

        ctx.shape = (H, W)
        ctx.ori_h = ori_h
        ctx.ori_w = ori_w
        ctx.step_size = step_size

        new_h = H * step_size
        new_w = W * step_size

        y = ys.new_empty((B, C, new_h, new_w))
        y[:, :, ::step_size, ::step_size] = ys[:, 0].reshape(B, C, H, W)
        y[:, :, 1::step_size, ::step_size] = ys[:, 1].reshape(B, C, W, H).transpose(dim0=2, dim1=3)
        y[:, :, ::step_size, 1::step_size] = ys[:, 2].reshape(B, C, H, W)
        y[:, :, 1::step_size, 1::step_size] = ys[:, 3].reshape(B, C, W, H).transpose(dim0=2, dim1=3)

        if ori_h != new_h or ori_w != new_w:
            y = y[:, :, :ori_h, :ori_w].contiguous()

        y = y.view(B, C, -1)
        return y

    @staticmethod
    def backward(ctx, grad_x: torch.Tensor):  # [B, C, H*W] -> [B, 4, C, H/w * W/w]
        H, W = ctx.shape
        B, C, L = grad_x.shape
        step_size = ctx.step_size

        grad_x = grad_x.view(B, C, ctx.ori_h, ctx.ori_w)

        if ctx.ori_w % step_size != 0:
            pad_w = step_size - ctx.ori_w % step_size
            grad_x = F.pad(grad_x, (0, pad_w, 0, 0))
        W = grad_x.shape[3]

        if ctx.ori_h % step_size != 0:
            pad_h = step_size - ctx.ori_h % step_size
            grad_x = F.pad(grad_x, (0, 0, 0, pad_h))
        H = grad_x.shape[2]

        B, C, H, W = grad_x.shape
        H = H // step_size
        W = W // step_size

        grad_xs = grad_x.new_empty((B, 4, C, H * W))
        grad_xs[:, 0] = grad_x[:, :, ::step_size, ::step_size].reshape(B, C, -1)
        grad_xs[:, 1] = grad_x.transpose(dim0=2, dim1=3)[:, :, ::step_size, 1::step_size].reshape(B, C, -1)
        grad_xs[:, 2] = grad_x[:, :, ::step_size, 1::step_size].reshape(B, C, -1)
        grad_xs[:, 3] = grad_x.transpose(dim0=2, dim1=3)[:, :, 1::step_size, 1::step_size].reshape(B, C, -1)

        return grad_xs, None, None, None
```
`EfficientMerge` is a custom PyTorch `autograd.Function`. Its job is to assemble the features of the four differently-offset scan branches back into one complete feature map in the forward pass, and to split the gradient correctly back into the four branches in the backward pass, keeping the whole pipeline end-to-end trainable.
- In the forward pass it receives an input of shape [B, 4, C, L], where B is the batch size, 4 is the number of scan directions, C is the channel count, and L is the flattened spatial length.
  - It first derives the downsampled spatial size (H, W) from `step_size`, then scatters the four branches into a new feature map in an interleaved fashion:
    - branch 0 goes to positions (0::step_size, 0::step_size);
    - branch 1 goes to (1::step_size, 0::step_size), after an H/W transpose;
    - branch 2 goes to (0::step_size, 1::step_size);
    - branch 3 goes to (1::step_size, 1::step_size), also transposed.
  - This interleaved filling treats the original image as a checkerboard grid: each branch is responsible for one kind of sub-cell, and together they reconstruct the full-resolution feature map.
  - If the assembled size exceeds the original input size, the result is cropped, and it is finally flattened to [B, C, H*W] for the subsequent computation.
- In the backward pass, the function reshapes the [B, C, H*W] gradient back to [B, C, H, W], then, following the forward checkerboard rule, slices the gradient at the interleaved positions back into the four branches corresponding to offsets (0,0), (1,0), (0,1), (1,1), applying the inverse transpose to the branches that need it. If the input height or width is not a multiple of `step_size`, the gradient is padded automatically. The result is a gradient of shape [B, 4, C, H/w * W/w], matching the forward input.
- Overall, `EfficientMerge` is a purpose-built, efficient assemble/split module: in the forward pass it stitches the multi-direction scan features back into the original image layout, and in the backward pass it guarantees that each branch receives the correct gradient, bridging multi-path scanning and the global image feature. A small round-trip sketch follows.
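Assuming the `EfficientScan` and `EfficientMerge` classes above are in scope (along with their `torch`, `math`, and `torch.nn.functional as F` imports), a quick round-trip check shows that the checkerboard scatter exactly inverts the strided sampling and that gradients flow through both custom Functions:

```python
import torch

x = torch.arange(16.).view(1, 1, 4, 4).requires_grad_(True)

xs = EfficientScan.apply(x, 2)             # (1, 4, 1, 4): the four offset branches
y = EfficientMerge.apply(xs, 4, 4, 2)      # (1, 1, 16): branches scattered back to the full grid

print(torch.equal(y.view(1, 1, 4, 4), x))  # True -- merge is the exact inverse of scan here
y.sum().backward()                          # gradients flow back through both custom Functions
print(x.grad.shape)                         # torch.Size([1, 1, 4, 4])
```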