Implementation Details of MAP
Table of Contents
- Implementation Details of MAP
- Encoder
- Decoder
- Loss Function
Implementation Details of MAP
Blog post: CVPR | 2025 | MAP: Unleashing Hybrid Mamba-Transformer Vision Backbone's Potential with Masked Autoregressive Pretraining
- Paper: https://arxiv.org/pdf/2410.00871
- Code: https://github.com/yunzeliu/MAP
- Code mirror: https://gitee.com/apuppyliu-cong/MAP.git
- Conference: CVPR
- Year: 2025
Encoder
```python
def forward_mae_encoder(self, x, mask):
    x = self.patch_embed(x)
    B, M, C = x.shape
    bsz, seq_len, embed_dim = x.shape
    x = x + self.pos_embed

    # dropping
    x = x[(1 - mask).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)

    # mamba impl
    residual = None
    hidden_states = x
    for n, layer in enumerate(self.layers):
        hidden_states, residual = layer(hidden_states, residual)
        if (n + 1) % 3 == 0:
            block_idx = n // 3
            if block_idx < len(self.blocks):
                block_residual = hidden_states
                hidden_states = self.blocks[block_idx](hidden_states)
                hidden_states += block_residual

    if not self.fused_add_norm:
        if residual is None:
            residual = hidden_states
        else:
            residual = residual + self.drop_path(hidden_states)
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
    else:
        # Set prenorm=False here since we don't need the residual
        fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
        hidden_states = fused_add_norm_fn(
            self.drop_path(hidden_states),
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
        )
    return hidden_states
```
Following the code flow:
- First, the input tensor x of shape [B, C, H, W] passes through x = self.patch_embed(x), which performs the patch embedding. Its implementation is as follows:
```python
self.patch_embed = PatchEmbed(img_size=img_size,
                              patch_size=patch_size,
                              in_chans=channels,
                              embed_dim=embed_dim)
# Here img_size=224, patch_size=16, channels=3, embed_dim=192.
# PatchEmbed is implemented as follows:
```
```python
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = ((img_size[0] - patch_size[0]) // patch_size[0] + 1,
                          (img_size[1] - patch_size[1]) // patch_size[1] + 1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
```
After this step, the input image of shape [B, 3, 224, 224] passes through the 2D projection x = self.proj(x) and becomes [B, 192, 14, 14]. Since self.flatten = flatten and flatten=True, the tensor is then flattened from [B, 192, 14, 14] to [B, 192, 196], and its last two dimensions are transposed to give [B, 196, 192]. Because no normalization is applied at the end (norm_layer=None, and self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()), the final output tensor has shape [B, 196, 192].
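To make the shape bookkeeping concrete, here is a minimal sanity check (a sketch, assuming the PatchEmbed class above with the stated hyperparameters and an arbitrary batch size of 2):

```python
import torch

# Illustrative shape check for the patch-embedding step.
patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=192)
imgs = torch.randn(2, 3, 224, 224)   # [B, 3, 224, 224]
tokens = patch_embed(imgs)           # Conv2d -> flatten(2) -> transpose(1, 2)
print(tokens.shape)                  # torch.Size([2, 196, 192])
```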
- Second, add the positional embedding

```python
x = x + self.pos_embed
#------------------------------------------------------------------------------
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
```

In this step (x = x + self.pos_embed) the tensor shape does not change and remains [B, 196, 192].
- Third, extract the visible sequence

```python
x = x[(1 - mask).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
```

Here, given the incoming mask and the position-encoded x, the model indexes the unmasked patches via x[(1 - mask).nonzero(as_tuple=True)] and reshapes them into [B, len_keep, 192], as sketched right below.
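A toy sketch of this gather (my own illustrative shapes, not from the repo). Note the reshape only works because random_masking draws a single mask_rate per batch, so every sample keeps the same number of tokens:

```python
import torch

B, L, C = 2, 6, 4
x = torch.arange(B * L * C, dtype=torch.float).reshape(B, L, C)

# mask: 1 = masked (dropped), 0 = visible (kept); same count per sample
mask = torch.tensor([[1, 0, 1, 0, 0, 1],
                     [0, 1, 0, 1, 1, 0]], dtype=torch.float)

visible = x[(1 - mask).nonzero(as_tuple=True)].reshape(B, -1, C)
print(visible.shape)   # torch.Size([2, 3, 4]) -> [B, len_keep, C]
```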
→ 3.1 Obtaining the mask matrix
```python
# patchify and mask (drop) tokens
x = self.patchify(imgs)
orders = self.sample_orders(bsz=x.size(0))
mask = self.random_masking(x, orders)

#--------------------------------------------------------------------------------
def patchify(self, imgs):
    """
    imgs: (N, 3, H, W)
    x: (N, L, patch_size**2 *3)
    """
    p = self.patch_embed.patch_size[0]
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
    return x

#---------------------------------------------------------------------------------
def sample_orders(self, bsz):
    # generate a batch of random generation orders
    orders = []
    for _ in range(bsz):
        order = np.array(list(range(self.seq_len)))
        np.random.shuffle(order)
        orders.append(order)
    orders = torch.Tensor(np.array(orders)).cuda().long()
    return orders

def random_masking(self, x, orders):
    # generate token mask
    bsz, seq_len, embed_dim = x.shape
    mask_rate = self.mask_ratio_generator.rvs(1)[0]
    num_masked_tokens = int(np.ceil(seq_len * mask_rate))
    mask = torch.zeros(bsz, seq_len, device=x.device)
    mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                         src=torch.ones(bsz, seq_len, device=x.device))
    return mask
```
#------------------------------------------------------------------------------
```python
'''
mask_ratio_max = 0.7, mask_ratio_min = 0.3, scale = 0.25
loc = (mask_ratio_max + mask_ratio_min) / 2
a = (mask_ratio_min - loc) / scale
b = (mask_ratio_max - loc) / scale
self.mask_ratio_generator = stats.truncnorm(a, b, loc=loc, scale=scale)
'''
```
For a complete input image x of shape [B, 3, 224, 224], the model first splits it into image tokens via x = self.patchify(imgs), after which the tensor has shape [B, 196, 768]. The token sequence is then passed to the ordering module, which shuffles the generation order, and random masking is applied. During random masking, mask_rate = self.mask_ratio_generator.rvs(1)[0], and with the parameters defined above (a=-0.8, b=0.8, loc=0.5, scale=0.25) the mask rate is not a fixed value but fluctuates within [0.3, 0.7].
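A small sketch of the truncated-normal sampler with exactly these parameters (illustrative; it only demonstrates that the sampled rates stay inside [0.3, 0.7]):

```python
import numpy as np
from scipy import stats

mask_ratio_min, mask_ratio_max, scale = 0.3, 0.7, 0.25
loc = (mask_ratio_max + mask_ratio_min) / 2   # 0.5
a = (mask_ratio_min - loc) / scale            # -0.8
b = (mask_ratio_max - loc) / scale            #  0.8
gen = stats.truncnorm(a, b, loc=loc, scale=scale)

rates = gen.rvs(5)
print(rates)
assert np.all((rates >= mask_ratio_min) & (rates <= mask_ratio_max))
```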
- layer
```python
self.blocks = nn.ModuleList([
    Block_vit(embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio / 2,
              qkv_bias=True, norm_layer=nn.LayerNorm)
    for i in range(int(depth // 3))
])
```
```python
#-------------------------------------------------------------------------------------
residual = None
hidden_states = x
for n, layer in enumerate(self.layers):
    hidden_states, residual = layer(hidden_states, residual)
    if (n + 1) % 3 == 0:
        block_idx = n // 3
        if block_idx < len(self.blocks):
            block_residual = hidden_states
            hidden_states = self.blocks[block_idx](hidden_states)
            hidden_states += block_residual
```
```python
#-------------------------------------------------------------------------------------
self.layers = nn.ModuleList([
    create_block(
        embed_dim,
        ssm_cfg=ssm_cfg,
        norm_epsilon=norm_epsilon,
        rms_norm=rms_norm,
        residual_in_fp32=residual_in_fp32,
        fused_add_norm=fused_add_norm,
        layer_idx=i,
        drop_path=inter_dpr[i],
        **factory_kwargs,
    )
    for i in range(depth)
])
```
#-------------------------------------------------------------------------------------
```python
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, biscan=True, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs)
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
```
#-------------------------------------------------------------------------------------
```python
class Block(nn.Module):
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False,
                 residual_in_fp32=False, drop_path=0.):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)), \
                "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            if residual is None:
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states,
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
            else:
                hidden_states, residual = fused_add_norm_fn(
                    self.drop_path(hidden_states),
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual
```
```python
#-------------------------------------------------------------------------------------
# depth=24
self.blocks = nn.ModuleList([
    Block_vit(embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio / 2,
              qkv_bias=True, norm_layer=nn.LayerNorm)
    for i in range(int(depth // 3))
])
```
Let's unpack this step by step. First, residual = None, i.e., the residual starts out empty (None), and the hidden state hidden_states = x starts out as x. residual and hidden_states are then fed into each layer as its inputs.
- The structure of layer:
As shown in the figure (see the paper overview in CVPR | 2025 | MAP: Unleashing Hybrid Mamba-Transformer Vision Backbone's Potential with Masked Autoregressive Pretraining), MAP ultimately adopts the MMMT-MMMT structure. In the code this appears as the condition if (n + 1) % 3 == 0, in which case block_idx = n // 3. (For example, after the 3rd layer, n = 2 and block_idx = 2 // 3 = 0, i.e., the first ViT block.)
- From depth = 24 in the code, most of the layers are Mamba blocks produced by create_block, stored in self.layers.
- In addition, self.blocks is built, containing depth // 3 = 8 ViT blocks.
The overall logic is therefore that the model inserts a ViT block after every three Mamba blocks. Roughly, the layer structure looks like:

```
Layer structure = [MambaBlock1,  MambaBlock2,  MambaBlock3  + ViTBlock1,
                   MambaBlock4,  MambaBlock5,  MambaBlock6  + ViTBlock2,
                   ...
                   MambaBlock22, MambaBlock23, MambaBlock24 + ViTBlock8]
```

This clearly matches the figure above.
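To see the MMMT schedule directly, here is a tiny sketch that just prints which ViT block (if any) runs after each Mamba layer (illustrative, with depth = 24 as in the code):

```python
depth = 24
num_vit_blocks = depth // 3   # 8

for n in range(depth):        # n indexes self.layers (the Mamba blocks)
    if (n + 1) % 3 == 0:
        block_idx = n // 3
        if block_idx < num_vit_blocks:
            print(f"Mamba layer {n + 1:2d} -> ViT block {block_idx}")
# ViT blocks 0..7 fire after Mamba layers 3, 6, 9, ..., 24.
```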
→ fused_add_norm()
```python
fused_add_norm = False
# __init__
self.fused_add_norm = fused_add_norm
```
```python
#------------------------------------------------------------------------------------------------
if not self.fused_add_norm:
    if residual is None:
        residual = hidden_states
    else:
        residual = residual + self.drop_path(hidden_states)
    hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
    # Set prenorm=False here since we don't need the residual
    fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
    hidden_states = fused_add_norm_fn(
        self.drop_path(hidden_states),
        self.norm_f.weight,
        self.norm_f.bias,
        eps=self.norm_f.eps,
        residual=residual,
        prenorm=False,
        residual_in_fp32=self.residual_in_fp32,
    )
```
At the very end of the encoder stack, one final Add + Norm is applied.
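For reference, a toy version of the unfused branch (fused_add_norm=False), with drop_path taken as identity; the shapes here are purely illustrative:

```python
import torch
import torch.nn as nn

hidden_states = torch.randn(2, 98, 192)   # e.g. 98 visible tokens of width 192
residual = torch.randn(2, 98, 192)
norm_f = nn.LayerNorm(192)

residual = residual + hidden_states       # Add (drop_path is identity here)
hidden_states = norm_f(residual.to(dtype=norm_f.weight.dtype))   # Norm
print(hidden_states.shape)                # torch.Size([2, 98, 192])
```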
This completes the detailed walkthrough of the MAP encoder.
Decoder
```python
def forward_mae_decoder(self, x, mask):
    x = self.decoder_embed(x)

    # pad mask tokens
    mask_tokens = self.mask_token.repeat(mask.shape[0], mask.shape[1], 1).to(x.dtype)
    x_after_pad = mask_tokens.clone()
    x_after_pad[(1 - mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

    # decoder position embedding
    x = x_after_pad + self.decoder_pos_embed_learned

    B, seq_len, C = x.shape
    tokens_per_row = int(seq_len ** 0.5)
    assert tokens_per_row ** 2 == seq_len, "seq_len is not a perfect square!"

    mask = torch.tril(torch.ones((tokens_per_row, tokens_per_row), dtype=torch.float)).to(x.device)
    mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0)
    mask = torch.repeat_interleave(mask, repeats=tokens_per_row, dim=0)
    mask = torch.repeat_interleave(mask, repeats=tokens_per_row, dim=1)

    for block in self.decoder_blocks:
        x = block(x, x, mask=mask)

    x = self.decoder_norm(x)
    x = self.ar_pred(x)
    return x
```
```python
x = self.decoder_embed(x)
```

First, the encoder output (hidden_states) is projected through a linear layer (nn.Linear) to the dimension the decoder expects, decoder_embed_dim:
```python
# embed_dim = 192
# decoder_embed_dim = 192
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
```
- The masking part
→ 2.1 Preparation
```python
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
#------------------------------------------------------------------------------
mask_tokens = self.mask_token.repeat(mask.shape[0], mask.shape[1], 1).to(x.dtype)
x_after_pad = mask_tokens.clone()
x_after_pad[(1 - mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
```
In this stage, the mask matrix initially has shape [N, L]. After mask_tokens = self.mask_token.repeat(mask.shape[0], mask.shape[1], 1).to(x.dtype), the resulting mask_tokens has shape [B, L, decoder_embed_dim]. Next, mask_tokens is cloned to obtain a writable copy.
The next line assigns into x_after_pad[(1 - mask).nonzero(as_tuple=True)]. Note that in the mask matrix a value of 1 means the patch at that position was masked out, while 0 means the patch was kept, so the total number of visible positions, i.e., the total number of visible tokens, is B * M. The right-hand side x is the encoder output over the visible tokens, with shape [B, M, C], where M is the number of visible tokens per image. This line therefore copies the visible-token features from x into the "unmasked" positions of the newly created x_after_pad in one shot.
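A toy sketch of this pad-and-scatter step (my own illustrative shapes; again the number of visible tokens must be the same for every sample):

```python
import torch

B, L, D = 2, 6, 4
mask = torch.tensor([[1, 0, 1, 0, 0, 1],
                     [0, 1, 0, 1, 1, 0]], dtype=torch.float)   # 1 = masked, 0 = visible
M = int((1 - mask[0]).sum())                                   # visible tokens per image (3)

mask_token = torch.zeros(1, 1, D)          # a learnable nn.Parameter in the real model
x_visible = torch.randn(B, M, D)           # decoder-embedded visible tokens, [B, M, C]

x_after_pad = mask_token.repeat(B, L, 1).clone()
x_after_pad[(1 - mask).nonzero(as_tuple=True)] = x_visible.reshape(B * M, D)
print(x_after_pad.shape)                   # torch.Size([2, 6, 4])
```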
→ 2.2 Adding the learnable decoder positional embedding
```python
num_patches = self.patch_embed.num_patches
self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim))
#------------------------------------------------------------------------------
x = x_after_pad + self.decoder_pos_embed_learned
B, seq_len, C = x.shape
tokens_per_row = int(seq_len ** 0.5)
assert tokens_per_row ** 2 == seq_len, "seq_len is not a perfect square!"
```
→ 2.3 Constructing the attention mask
```python
mask = torch.tril(torch.ones((tokens_per_row, tokens_per_row), dtype=torch.float)).to(x.device)
mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0)
mask = torch.repeat_interleave(mask, repeats=tokens_per_row, dim=0)
mask = torch.repeat_interleave(mask, repeats=tokens_per_row, dim=1)
```
First, mask = torch.tril(torch.ones((tokens_per_row, tokens_per_row), dtype=torch.float)).to(x.device) generates a lower-triangular matrix. It is then converted into attention-mask form by mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0): every 0 becomes -inf (attention forbidden) and every 1 becomes 0 (attention allowed).
This is the standard attention-mask representation in Transformers:
- the mask is added to the scores before the softmax: score + mask
- positions holding -inf receive a softmax weight of 0 and cannot be attended to
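A two-line illustration of why the -inf entries suppress attention (not from the repo, just the standard softmax behaviour):

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([0.0, float('-inf'), 0.0])   # middle position is forbidden
print(torch.softmax(scores + mask, dim=-1))      # middle weight is exactly 0
```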
```python
mask = torch.repeat_interleave(mask, repeats=tokens_per_row, dim=0)
mask = torch.repeat_interleave(mask, repeats=tokens_per_row, dim=1)
```

At this point the mask only encodes the coarse row-level causal relationship and has size [tokens_per_row, tokens_per_row]. repeat_interleave(..., dim=0) repeats each row tokens_per_row times, and repeat_interleave(..., dim=1) repeats each column tokens_per_row times, expanding it to the full [seq_len, seq_len] token-level mask.
decoder_blocks
```python
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```
#------------------------------------------------------------------------------
```python
class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, mask):
        B, N, C = q.shape
        q = self.q(q).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        kv = self.kv(kv).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = q[0], kv[0], kv[1]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn += mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```
#------------------------------------------------------------------------------
```python
class DecoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.attn2 = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.norm2_1 = norm_layer(dim)
        self.norm2_2 = norm_layer(dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, q, kv, mask):
        q = q + self.attn2(self.norm2_1(q), self.norm2_2(kv), mask)
        q = q + self.mlp(self.norm2(q))
        return q
```
```python
#------------------------------------------------------------------------------
self.decoder_blocks = nn.ModuleList([
    DecoderBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio,
                 qkv_bias=True, norm_layer=nn.LayerNorm)
    for _ in range(4)
])
```
```python
#------------------------------------------------------------------------------
for block in self.decoder_blocks:
    x = block(x, x, mask=mask)
```
- x is fed in as both q and kv, so each block effectively performs self-attention
- mask restricts which positions each token is allowed to attend to
- after each block, the features of x are updated by cross-attention + MLP
- looping 4 times gives 4 decoder layers and the final token representations
```python
x = self.decoder_norm(x)
x = self.ar_pred(x)
return x
```
This produces the decoder output.
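Putting the pieces together, a toy run of the decoder block stack (a sketch assuming the Mlp / CrossAttention / DecoderBlock classes above; the head count of 3 and the all-zero attention mask are illustrative choices of mine, not values taken from the repo):

```python
import torch
import torch.nn as nn

decoder_embed_dim, decoder_num_heads, mlp_ratio = 192, 3, 4.0   # head count is illustrative
decoder_blocks = nn.ModuleList([
    DecoderBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio,
                 qkv_bias=True, norm_layer=nn.LayerNorm)
    for _ in range(4)
])

x = torch.randn(2, 196, decoder_embed_dim)   # padded tokens + decoder positional embedding
attn_mask = torch.zeros(196, 196)            # all zeros = no restriction (toy stand-in)
for block in decoder_blocks:
    x = block(x, x, mask=attn_mask)
print(x.shape)                               # torch.Size([2, 196, 192])
```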
Loss Function
As shown below, the classic MSE loss is used, the same as in the original MAE.
```python
def forward_loss_mae(self, imgs, pred, mask):
    """
    imgs: [N, 3, H, W]
    pred: [N, L, p*p*3]
    mask: [N, L], 0 is keep, 1 is remove,
    """
    norm_pix_loss = True
    target = self.patchify(imgs)
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6) ** .5

    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)                  # [N, L], mean loss per patch
    loss = (loss * mask).sum() / mask.sum()   # mean loss on removed patches
    return loss
```
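A tiny numerical check that the loss is averaged only over the masked (removed) patches (illustrative tensors; the per-patch pixel normalization is omitted for brevity):

```python
import torch

pred = torch.tensor([[[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]]])   # [N=1, L=3, p*p*3=2]
target = torch.zeros_like(pred)
mask = torch.tensor([[1.0, 0.0, 1.0]])                        # patches 0 and 2 were masked

loss = ((pred - target) ** 2).mean(dim=-1)                    # [N, L] mean loss per patch
loss = (loss * mask).sum() / mask.sum()
print(loss)                                                   # tensor(2.5000) = (1 + 4) / 2
```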