当前位置: 首页 > news >正文

深度学习·经典模型·SwinTransformer

SwinTransformer

  • 主要创新点:移动窗口,基于窗口的注意力计算

Patch Embedding

  • 下采样打包为Pacth:可以直接使用Conv2d

  • 也可以先打包后使用embedding映射。

Patch Merging

  • 类似池化的操作,压缩图片大小,同时通道数增多,获得更多的语义信息。

  • 实现:获得相邻的Patch,然后在通道维度上concat,维度变为 4 C 4C 4C,最后经过线性层投射回 2 C 2C 2C

例子

  • 1,2,5,6是相邻的2x2的Patch
[	[1,2,3,4][5,6,7,8][9,10,11,12][13,14,15,16]]
  • 通过切片获得对应位置的元素
  • 注意我们从通道维度上拼接,所以不能按照传统的上下拼接的思路理解
  • 这段代码的效果是:编号1,2,5,6的特征向量拼接,(相邻元素就好像叠加在一起)
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]

窗口注意力机制

  • 本文使用了窗口注意力机制,计算复杂度是随着图像扩大线性增长的

  • 这点就好比 8 ∗ 8 = 64 > > > 4 ∗ ( 2 ∗ 2 ) = 16 8*8=64>>>4*(2*2)=16 88=64>>>4(22)=16),小窗口 2 ∗ 2 2*2 22比大窗口 8 ∗ 8 8*8 88明显复杂度低很多。

实现讲解

  • 输入: ( B ∗ N w , M h ∗ M w , C ) (B*N_w,M_h*M_w,C) (BNw,MhMw,C) N w N_w Nw是窗口数量,可以参加Embedding层 N w = H ∗ W M h ∗ M w N_w=\frac{H*W}{M_h*M_w} Nw=MhMwHW
  • 输入的理解:将窗口数量理解为一种批次 M h ∗ M w M_h*M_w MhMw作为序列的长度,reshape为指定维度: ( B ∗ N w , M h ∗ N w , C ) (B*N_w,M_h*N_w,C) (BNw,MhNw,C)
  • 快速计算KQV,直接使用线性层映射 ( B ∗ N w , M h ∗ N w , 3 C ) (B*N_w,M_h*N_w,3C) (BNw,MhNw,3C),然后拆分最后一个维度 3 C 3C 3C,变成各自 ( 3 , B ∗ N w , M h ∗ N w , C ) (3,B*N_w,M_h*N_w,C) (3BNw,MhNw,C)的QKV大小,为分离QKV作准备。
  • 多头注意力机制:每一个KQV维度 ( 3 , B ∗ N w , M h ∗ N w , C ) (3,B*N_w,M_h*N_w,C) (3,BNw,MhNw,C),转换为 ( 3 , B ∗ N w , N h e a d , M h ∗ N w , d i m h e a d ) (3,B*N_w,N_{head},M_h*N_w,dim_{head}) (3BNw,Nhead,MhNw,dimhead), N h e a d N_{head} Nhead不会参与计算,只需要最后两个维度进行KQV的矩阵乘法即可获得最终的多头注意力输出!
  • 然后就是Masked掩码操作:这里使用的是加性掩码,掩码的生成方式见下。
  • 输出维度不变: ( B ∗ N w , M h ∗ M w , C ) (B*N_w,M_h*M_w,C) (BNw,MhMw,C)
    def forward(self, x, mask: Optional[torch.Tensor] = None):"""Args:x: input features with shape of (num_windows*B, Mh*Mw, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""# [batch_size*num_windows, Mh*Mw, total_embed_dim]B_, N, C = x.shape# qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]# reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]# permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)# transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]q = q * self.scaleattn = (q @ k.transpose(-2, -1))# relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:# mask: [nW, Mh*Mw, Mh*Mw]nW = mask.shape[0]  # num_windows# attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]# mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]# transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]# reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x

滑动窗口的实现

  • 本文的核心操作:实现起来不难
  • 实现代码:注意图像整体往右下,roll这个函数是相当于移动窗口的,所以是往左上移动窗口
  • 输入和输出是以图片的格式: ( B , H ∗ W , C ) (B,H*W,C) (B,HW,C)
        if self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

移动窗口注意力

  • 先调用移动窗口:对图像进行移动处理。
  • 使用被移动后的图像进行窗口注意力计算,输出维度 ( B ∗ N w , M h ∗ M w , C ) (B*N_w,M_h*M_w,C) (BNw,MhMw,C)
  • 还原为图像 ( B , H , W , C ) (B,H,W,C) (B,H,W,C)
  • 反方向移动图像:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  • reshape为: ( B , H ∗ W , C ) (B, H * W, C) (B,HW,C),丢入MLP中处理,放大 4 C 4C 4C,然后还原为 C C C

在这里插入图片描述

在这里插入图片描述

MASK的实现

  • 建议直接抄以下代码:
    def create_mask(self, x, H, W):# calculate attention mask for SW-MSA# 保证Hp和Wp是window_size的整数倍Hp = int(np.ceil(H / self.window_size)) * self.window_sizeWp = int(np.ceil(W / self.window_size)) * self.window_size# 拥有和feature map一样的通道排列顺序,方便后续window_partitionimg_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]# [nW, Mh*Mw, Mh*Mw]attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))return attn_mask

在这里插入图片描述

相关文章:

  • C语言教程(二十三):C 语言强制类型转换详解
  • C++核心编程 1.2 程序运行后
  • 【阿里云大模型高级工程师ACP习题集】2.7 通过微调增强模型能力 (上篇)(⭐️⭐️⭐️ 重点章节!!!)
  • 什么是缓冲区溢出?NGINX是如何防止缓冲区溢出攻击的?
  • LangChain4j +DeepSeek大模型应用开发——5 持久化聊天记忆 Persistence
  • Linux 命名管道+日志
  • 微信小程序开发,购物商城实现
  • 阿里通义Qwen3:双引擎混合推理,119语言破局全球AI竞赛
  • Golang 并发编程
  • 厚铜PCB钻孔工艺全解析:从参数设置到孔壁质量的关键控制点
  • Sql刷题日志(day7)
  • BG开发者日志429:故事模式的思路
  • 免费超好用的电脑操控局域网内的手机(多台,无线)
  • 开放平台架构方案- GraphQL 详细解释
  • 信息系统项目管理工程师备考计算类真题讲解十一
  • 为什么业务总是被攻击?使用游戏盾解决方案
  • 通过全局交叉注意力机制和距离感知训练从多模态数据中识别桥本氏甲状腺炎|文献速递-深度学习医疗AI最新文献
  • 生物信息学常用软件InSequence,3大核心功能,简易好上手
  • 雅思口语高频词汇表达
  • 深度学习篇---模型权重变化与维度分析
  • 北大深圳研究生院成立科学智能学院:培养交叉复合型人才
  • 日本希望再次租借大熊猫,外交部:双方就相关合作保持密切沟通
  • 马上评丨又见酒店坐地起价,“老毛病”不能惯着
  • “自己生病却让别人吃药”——抹黑中国经济解决不了美国自身问题
  • 报告显示2024年全球军费开支增幅达冷战后最大
  • 修订占比近30%收录25万条目,第三版《英汉大词典》来了