RTDETR融合[CVPR2024]SHViT中的SHSA模块
RT-DETR使用教程: RT-DETR使用教程
RT-DETR改进汇总贴:RT-DETR更新汇总贴
《SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design》
一、 模块介绍
论文链接:https://arxiv.org/abs/2401.16456
代码链接:https://github.com/ysj9909/SHViT/blob/main/model/shvit.py
论文速览:
高效的 Vision Transformer 在资源受限的设备上表现出出色的性能和低延迟。传统上,它们在宏观层面使用 4×4 的块嵌入(patch embedding)和 4 阶段结构,在微观层面使用复杂的多头注意力配置。本文旨在以内存高效的方式解决所有设计层级上的计算冗余。我们发现,使用更大步长的块嵌入(patchify stem)不仅降低了内存访问成本,而且能够利用从早期阶段起就减少了空间冗余的 token 表示,从而实现具有竞争力的性能。此外,我们的初步分析表明,早期阶段的注意力层可以用卷积代替,而后期阶段的多个注意力头在计算上是冗余的。为了解决这个问题,我们引入了一个单头注意力模块,它从根本上避免了头冗余,同时通过并行结合全局和局部信息来提高准确性。在此基础上,我们提出了单头视觉 Transformer SHViT,它实现了最先进的速度与精度权衡。例如,在 ImageNet-1k 上,我们的 SHViT-S4 在 GPU、CPU 和 iPhone 12 移动设备上分别比 MobileViTv2 ×1.0 快 3.3 倍、8.1 倍和 2.4 倍,同时准确率高 1.3%。对于使用 Mask R-CNN 检测头的 MS COCO 目标检测和实例分割,我们的模型性能与 FastViT-SA12 相当,而在 GPU 和移动设备上的骨干网络延迟分别低 3.8 倍和 2.0 倍。
总结:本文更新其中的SHSA模块代码及使用方法。
⭐⭐本文二创模块仅更新于付费群中,往期免费教程可看下方链接⭐⭐
RT-DETR更新汇总贴(含免费教程,包括RT-DETR使用教程、缝合教程、RT-DETR中的yaml文件详解、labelimg使用教程):https://xy2668825911.blog.csdn.net/article/details/143696113
二、二创融合模块
2.1 相关代码
# SHViT building blocks (Conv2d_BN, SHSA, BasicBlock, SHViT, ...) adapted from
# https://github.com/ysj9909/SHViT/blob/main/model/shvit.py
import torch
import itertools

try:  # newer timm releases expose these helpers under `timm.layers`
    from timm.layers import SqueezeExcite, trunc_normal_
except ImportError:  # older timm: fall back to the original import locations
    from timm.models.layers import SqueezeExcite
    from timm.models.vision_transformer import trunc_normal_


class GroupNorm(torch.nn.GroupNorm):
    """Group Normalization with a single group.

    Input: tensor in shape [B, C, H, W].
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class Conv2d_BN(torch.nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``, foldable into one conv.

    Args:
        a: input channels.
        b: output channels.
        ks, stride, pad, dilation, groups: forwarded to ``torch.nn.Conv2d``.
        bn_weight_init: initial BN scale. ``0`` makes the layer output start
            at zero, which is how the residual branches below are initialised.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return a single ``Conv2d`` equal to conv + BN (using BN running stats)."""
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(
            w.size(1) * self.c.groups, w.size(0), w.shape[2:],
            stride=self.c.stride, padding=self.c.padding,
            dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class BN_Linear(torch.nn.Sequential):
    """``BatchNorm1d`` followed by ``Linear``, foldable into a single ``Linear``.

    Args:
        a: input features.
        b: output features.
        bias: whether the ``Linear`` layer has a bias.
        std: std of the truncated-normal initialisation of the ``Linear`` weight.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return a single ``Linear`` equal to BN (running stats) + Linear."""
        bn, lin = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        b = bn.bias - self.bn.running_mean * \
            self.bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = lin.weight * w[None, :]
        if lin.bias is None:
            b = b @ self.l.weight.T
        else:
            b = (lin.weight @ b[:, None]).view(-1) + self.l.bias
        # Build the fused layer on the source weights' device, like
        # Conv2d_BN.fuse does; otherwise it would always be created on the CPU.
        m = torch.nn.Linear(w.size(1), w.size(0), device=lin.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class PatchMerging(torch.nn.Module):
    """Stride-2 downsampling block.

    1x1 conv expanding to ``4 * dim`` -> ReLU -> 3x3 depthwise conv (stride 2)
    -> ReLU -> squeeze-excite -> 1x1 conv projecting to ``out_dim``.

    Args:
        dim: input channels.
        out_dim: output channels.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0)
        self.act = torch.nn.ReLU()
        self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2d_BN(hid_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
        return x


class Residual(torch.nn.Module):
    """``x + m(x)`` with optional per-sample stochastic depth while training.

    Args:
        m: the residual branch.
        drop: probability of dropping the branch for a sample during
            training; kept samples are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            # The (B, 1, 1, 1) keep-mask assumes a 4-D input.
            return x + self.m(x) * torch.rand(
                x.size(0), 1, 1, 1, device=x.device
            ).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)

    @torch.no_grad()
    def fuse(self):
        """Fold the skip connection into a depthwise ``Conv2d_BN`` branch.

        An identity kernel (1 at the centre tap) is added to the fused conv
        weights. The padding of 1 on each side assumes a 3x3 kernel. Any other
        branch type is returned unchanged (as ``self``).
        """
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self


class FFN(torch.nn.Module):
    """Pointwise feed-forward: 1x1 conv (ed -> h) -> ReLU -> 1x1 conv (h -> ed).

    The last BN scale starts at zero, so the block's output is initially zero.
    """

    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = torch.nn.ReLU()
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x


class SHSA(torch.nn.Module):
    """Single-Head Self-Attention.

    Only the first ``pdim`` channels are normalised and attended over. The
    remaining ``dim - pdim`` channels bypass attention and are concatenated
    back before the ReLU + 1x1 output projection.

    Args:
        dim: number of input/output channels.
        qk_dim: channel width of the query/key projections. The default of 16
            is the value every stage of ``SHViT``'s defaults uses.
        pdim: number of channels that take part in attention. ``None``
            (default) uses ``dim // 4`` (at least 1), the ratio of
            ``SHViT``'s default ``partial_dim`` to ``embed_dim``. The
            defaults let the module be built from ``dim`` alone, e.g. when a
            YAML config gives it no explicit arguments.

    Raises:
        ValueError: if ``pdim`` is not in ``(0, dim]``.

    Note:
        The attention matrix is ``(H*W) x (H*W)`` per sample, so memory grows
        quadratically with the spatial size of the input feature map.
    """

    def __init__(self, dim, qk_dim=16, pdim=None):
        super().__init__()
        if pdim is None:
            pdim = max(dim // 4, 1)
        if not 0 < pdim <= dim:
            raise ValueError(f"pdim must be in (0, dim={dim}], got {pdim}")
        self.scale = qk_dim ** -0.5
        self.qk_dim = qk_dim
        self.dim = dim
        self.pdim = pdim
        self.pre_norm = GroupNorm(pdim)
        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
        self.proj = torch.nn.Sequential(
            torch.nn.ReLU(), Conv2d_BN(dim, dim, bn_weight_init=0))

    def forward(self, x):
        B, C, H, W = x.shape
        x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim=1)
        x1 = self.pre_norm(x1)
        qkv = self.qkv(x1)
        q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim=1)
        # Flatten the spatial dims: q, k -> (B, qk_dim, N), v -> (B, pdim, N)
        # with N = H * W.
        q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
        attn = (q.transpose(-2, -1) @ k) * self.scale  # (B, N, N)
        attn = attn.softmax(dim=-1)
        x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
        x = self.proj(torch.cat([x1, x2], dim=1))
        return x


class BasicBlock(torch.nn.Module):
    """SHViT block: residual 3x3 depthwise conv -> token mixer -> residual FFN.

    Args:
        dim: channel count.
        qk_dim, pdim: forwarded to ``SHSA`` (used only when ``type == "s"``).
        type: ``"s"`` uses a residual SHSA mixer (later stages); ``"i"`` uses
            no mixer (early stages).

    Raises:
        ValueError: if ``type`` is neither ``"s"`` nor ``"i"``.
    """

    def __init__(self, dim, qk_dim, pdim, type):
        super().__init__()
        if type not in ("s", "i"):
            raise ValueError(f'BasicBlock type must be "s" or "i", got {type!r}')
        self.conv = Residual(
            Conv2d_BN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
        if type == "s":  # for later stages
            self.mixer = Residual(SHSA(dim, qk_dim, pdim))
        else:  # "i": for early stages
            self.mixer = torch.nn.Identity()
        self.ffn = Residual(FFN(dim, int(dim * 2)))

    def forward(self, x):
        return self.ffn(self.mixer(self.conv(x)))


class SHViT(torch.nn.Module):
    """Single-Head Vision Transformer classifier with three stages.

    Args:
        in_chans: input image channels.
        num_classes: classifier outputs. ``0`` replaces the head(s) with Identity.
        embed_dim, partial_dim, qk_dim, depth, types: per-stage settings.
        down_ops: per-stage downsample op. ``('subsample', 2)`` inserts a
            conv/FFN + ``PatchMerging`` + conv/FFN transition at the start of
            the *next* stage.
        distillation: add a second head. In eval mode the two head outputs
            are averaged; in training mode both are returned as a tuple.

    Raises:
        ValueError: if more than three stages are configured, or if the
            last stage requests a ``'subsample'`` transition.
    """

    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=(128, 256, 384),
                 partial_dim=(32, 64, 96),
                 qk_dim=(16, 16, 16),
                 depth=(1, 2, 3),
                 types=("s", "s", "s"),
                 down_ops=(('subsample', 2), ('subsample', 2), ('',)),
                 distillation=False,):
        super().__init__()

        # Patch embedding: four stride-2 3x3 convs (total stride 16).
        self.patch_embed = torch.nn.Sequential(
            Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1))

        # One list of layers per stage, turned into blocks1..blocks3 below.
        stages = [[], [], []]
        for i, (ed, kd, pd, dpth, do, t) in enumerate(
                zip(embed_dim, qk_dim, partial_dim, depth, down_ops, types)):
            if i >= len(stages):
                raise ValueError("SHViT is built with at most 3 stages")
            for _ in range(dpth):
                stages[i].append(BasicBlock(ed, kd, pd, t))
            if do[0] == 'subsample':
                if i + 1 >= min(len(stages), len(embed_dim)):
                    raise ValueError(
                        "the last stage cannot request a 'subsample' transition")
                # Downsampling transition prepended to the next stage.
                blk = stages[i + 1]
                blk.append(torch.nn.Sequential(
                    Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1,
                                       groups=embed_dim[i])),
                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
                blk.append(PatchMerging(*embed_dim[i:i + 2]))
                blk.append(torch.nn.Sequential(
                    Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1,
                                       groups=embed_dim[i + 1])),
                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))),))
        self.blocks1 = torch.nn.Sequential(*stages[0])
        self.blocks2 = torch.nn.Sequential(*stages[1])
        self.blocks3 = torch.nn.Sequential(*stages[2])

        # Classification head
        self.head = BN_Linear(embed_dim[-1], num_classes) \
            if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BN_Linear(embed_dim[-1], num_classes) \
                if num_classes > 0 else torch.nn.Identity()

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x
2.2 更改yaml文件 (以自研模型加入为例)
yaml文件解读:YOLO系列 “.yaml”文件解读_yolo yaml文件-CSDN博客
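打开 ultralytics/cfg/models/rt-detr 路径下的 rtdetr-l.yaml 文件,用下方配置替换原有内容。注意:配置中 SHSA 的参数列表为空([]),而 SHSA 的构造函数至少需要 dim(输入通道数)。因此还需要在 ultralytics/nn/tasks.py 中导入 SHSA,并在 parse_model 中把该层的输入通道数作为第一个参数传入;若 qk_dim、pdim 没有默认值,也需一并提供。否则构建模型时会因缺少参数而报错。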
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# ⭐⭐Powered by https://blog.csdn.net/StopAndGoyyy, technical support QQ: 2668825911⭐⭐

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 512]
  # n: [ 0.33, 0.25, 1024 ]
  # s: [ 0.33, 0.50, 1024 ]
  # m: [ 0.67, 0.75, 768 ]
  # l: [ 1.00, 1.00, 512 ]
  # x: [ 1.00, 1.25, 512 ]

# ⭐⭐Powered by https://blog.csdn.net/StopAndGoyyy, technical support QQ: 2668825911⭐⭐
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, CCRI, [128, 5, True, False]] # 2
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 1, SHSA, []] # 4 NOTE: args are empty; ultralytics/nn/tasks.py must pass SHSA its input channel count (dim)
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 4, CCRI, [512, 3, True, True]] # 6
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, CCRI, [1024, 3, True, False]] # 8

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 10
  - [-1, 1, Conv, [256, 1, 1]] # 11, Y5, lateral_convs.0
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]] # 12
  - [6, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 13 input_proj.1
  - [[-2, -1], 1, Concat, [1]] # 14 cat backbone P4
  - [-1, 2, RepC4, [256]] # 15, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 16, Y4, lateral_convs.1
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]] # 17
  - [4, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 18 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # 19 cat backbone P3
  - [-1, 2, RepC4, [256]] # X3 (20), fpn_blocks.1
  - [-1, 1, Conv, [256, 3, 2]] # 21, downsample_convs.0
  - [[-1, 16], 1, Concat, [1]] # 22 cat Y4
  - [-1, 2, RepC4, [256]] # F4 (23), pan_blocks.0
  - [-1, 1, Conv, [256, 3, 2]] # 24, downsample_convs.1
  - [[-1, 11], 1, Concat, [1]] # 25 cat Y5
  - [-1, 2, RepC4, [256]] # F5 (26), pan_blocks.1
  - [[20, 23, 26], 1, RTDETRDecoder, [nc]] # 27 Detect(P3, P4, P5)
# ⭐⭐Powered by https://blog.csdn.net/StopAndGoyyy, technical support QQ: 2668825911⭐⭐
2.3 修改train.py文件
创建Train_RT脚本用于训练。
import os

# Allow duplicate OpenMP runtimes to coexist (presumably a workaround for the
# Intel OpenMP "already initialized" abort seen on some Windows/conda setups).
# Set it before ultralytics (and therefore torch) is imported so the runtime
# is guaranteed to see it.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from ultralytics.models import RTDETR  # noqa: E402  (must follow the env var)


def main():
    """Build RT-DETR from the modified yaml and start training."""
    model = RTDETR(model='ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    # model.load('yolov8n.pt')  # optional: load pretrained weights (must match this architecture)
    # epochs=2 / batch=1 look like smoke-test settings; raise them for real training.
    model.train(data='./data.yaml', epochs=2, batch=1, device='0', imgsz=640,
                workers=2, cache=False, amp=True, mosaic=False,
                project='runs/train', name='exp')


if __name__ == '__main__':
    main()
在train.py脚本中填入修改好的yaml路径,运行即可开始训练。