当前位置: 首页 > news >正文

YOLOV11改进策略【最新注意力机制】CVPR2025局部区域注意力机制LRSA-增强局部区域特征之间的交互

1.1网络结构

1.2 添加过程

1.2.1 核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrangedef patch_divide(x, step, ps):"""Crop image into patches.Args:x (Tensor): Input feature map of shape(b, c, h, w).step (int): Divide step.ps (int): Patch size.Returns:crop_x (Tensor): Cropped patches.nh (int): Number of patches along the horizontal direction.nw (int): Number of patches along the vertical direction."""b, c, h, w = x.size()if h == ps and w == ps:step = pscrop_x = []nh = 0for i in range(0, h + step - ps, step):top = idown = i + psif down > h:top = h - psdown = hnh += 1for j in range(0, w + step - ps, step):left = jright = j + psif right > w:left = w - psright = wcrop_x.append(x[:, :, top:down, left:right])nw = len(crop_x) // nhcrop_x = torch.stack(crop_x, dim=0)  # (n, b, c, ps, ps)crop_x = crop_x.permute(1, 0, 2, 3, 4).contiguous()  # (b, n, c, ps, ps)return crop_x, nh, nwdef patch_reverse(crop_x, x, step, ps):"""Reverse patches into image.Args:crop_x (Tensor): Cropped patches.x (Tensor): Feature map of shape(b, c, h, w).step (int): Divide step.ps (int): Patch size.Returns:ouput (Tensor): Reversed image."""b, c, h, w = x.size()output = torch.zeros_like(x)index = 0for i in range(0, h + step - ps, step):top = idown = i + psif down > h:top = h - psdown = hfor j in range(0, w + step - ps, step):left = jright = j + psif right > w:left = w - psright = woutput[:, :, top:down, left:right] += crop_x[:, index]index += 1for i in range(step, h + step - ps, step):top = idown = i + ps - stepif top + ps > h:top = h - psoutput[:, :, top:down, :] /= 2for j in range(step, w + step - ps, step):left = jright = j + ps - stepif left + ps > w:left = w - psoutput[:, :, :, left:right] /= 2return outputclass PreNorm(nn.Module):"""Normalization layer.Args:dim (int): Base channels.fn (Module): Module after normalization."""def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class Attention(nn.Module):"""Attention module.Args:dim (int): Base channels.heads (int): Head numbers.qk_dim (int): Channels of query and key."""def __init__(self, dim, heads, qk_dim):super().__init__()self.heads = headsself.dim = dimself.qk_dim = qk_dimself.scale = qk_dim ** -0.5self.to_q = nn.Linear(dim, qk_dim, bias=False)self.to_k = nn.Linear(dim, qk_dim, bias=False)self.to_v = nn.Linear(dim, dim, bias=False)self.proj = nn.Linear(dim, dim, bias=False)def forward(self, x):q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))out = F.scaled_dot_product_attention(q, k, v)  # scaled_dot_product_attention 需要PyTorch2.0之后版本out = rearrange(out, 'b h n d -> b n (h d)')return self.proj(out)class dwconv(nn.Module):def __init__(self, hidden_features, kernel_size=5):super(dwconv, self).__init__()self.depthwise_conv = nn.Sequential(nn.Conv2d(hidden_features, hidden_features, kernel_size=kernel_size, stride=1,padding=(kernel_size - 1) // 2, dilation=1,groups=hidden_features), nn.GELU())self.hidden_features = hidden_featuresdef forward(self, x, x_size):x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous()  # b Ph*Pw cx = self.depthwise_conv(x)x = x.flatten(2).transpose(1, 2).contiguous()return xclass ConvFFN(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, kernel_size=5, act_layer=nn.GELU):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.dwconv = dwconv(hidden_features=hidden_features, kernel_size=kernel_size)self.fc2 = nn.Linear(hidden_features, out_features)def forward(self, x, x_size):x = self.fc1(x)x = self.act(x)x = x + self.dwconv(x, x_size)x = self.fc2(x)return xclass LRSA(nn.Module):"""Attention module.Args:dim (int): Base channels.num (int): Number of blocks.qk_dim (int): Channels of query and key in Attention.mlp_dim (int): Channels of hidden mlp in Mlp.heads (int): Head numbers of Attention."""def __init__(self, dim, qk_dim=36, mlp_dim=96, heads=4):super().__init__()self.layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads, qk_dim)),PreNorm(dim, ConvFFN(dim, mlp_dim))])def forward(self, x, ps=16):step = ps - 2crop_x, nh, nw = patch_divide(x, step, ps)  # (b, n, c, ps, ps)b, n, c, ph, pw = crop_x.shapecrop_x = rearrange(crop_x, 'b n c h w -> (b n) (h w) c')attn, ff = self.layercrop_x = attn(crop_x) + crop_xcrop_x = rearrange(crop_x, '(b n) (h w) c  -> b n c h w', n=n, w=pw)x = patch_reverse(crop_x, x, step, ps)_, _, h, w = x.shapex = rearrange(x, 'b c h w-> b (h w) c')x = ff(x, x_size=(h, w)) + xx = rearrange(x, 'b (h w) c->b c h w', h=h)return x

1.2.2 添加过程

第一步:在nn文件夹下面创建一个新的文件夹,名字叫做Addmodules, 然后再这个文件夹下面创建LRSA.py文件,然后将LRSA的核心代码放入其中。

第二步:继续在Addmodules文件夹下面创建__init__.py文件,输入以下代码:


from .LRSA import LRSA
__all__ = ("LRSA",)

第三步在task.py文件中导入LRSA

from .Addmodules import LRSA
 elif m is LRSA:c1 = ch[f]args = [c1, *args]

第四步:修改yaml文件

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license# Ultralytics YOLO11 object detection model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo11
# Task docs: https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11n.yaml with scale 'n'# [depth, width, max_channels]n: [0.50, 0.25, 1024] # summary: 181 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss: [0.50, 0.50, 1024] # summary: 181 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm: [0.50, 1.00, 512] # summary: 231 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl: [1.00, 1.00, 512] # summary: 357 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx: [1.00, 1.50, 512] # summary: 357 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs# YOLO11n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 2, C3k2, [256, False, 0.25]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 2, C3k2, [512, False, 0.25]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 2, C3k2, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 2, C3k2, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 2, C2PSA, [1024]] # 10- [-1, 3, LRSA, [1024]]
# YOLO11n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 2, C3k2, [512, False]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)

请多多关注!!!

相关文章:

  • Android屏幕适配利器:Kotlin动态尺寸计算工具类完整封装
  • 矩阵:线性代数在AI大模型中的核心支柱
  • Kotlin-数组,集合类以及序列
  • SQLMesh Typed Macros:让SQL宏更强大、更安全、更易维护
  • 探索C++对象模型:This指针与(构造、析构)成员函数的深度解析(中篇)
  • iOS使用Metal对采集视频进行渲染
  • OpenHarmony外设驱动使用 (十三),Vibrator
  • Java桌面应用开发详解:自制截图工具从设计到打包的全流程【附源码与演示】
  • 2025年渗透测试面试题总结-匿名[社招]安全工程师(红队方向)2(题目+回答)
  • Linux(5)——再谈操作系统
  • 【AS32X601驱动系列教程】SMU_系统时钟详解
  • RNN GRU LSTM 模型理解
  • 飞桨(PaddlePaddle)在机器学习全流程(数据采集、处理、标注、建模、分析、优化)
  • 前端vue2-完全前端生成pdf->pdf-lib,html2canvas+jspdf,原生打印,三种方式(打印带有echarts图的pdf)
  • 可视化大屏实现全屏或非全屏
  • 继电保护与安全自动装置:电力系统安全的守护神
  • Windows 安装 FFmpeg 新手教程(附环境变量配置)
  • ProfiNet转Ethernet/IP网关选型策略适配西门子S7-1500与三菱变频器的关键参数对比
  • Oracle Apps R12——报表入门2:单表——报表开发流程
  • .NET外挂系列:6. harmony中一些实用的反射工具包
  • html怎么制作动态网站/今日新闻事件
  • 网站运营如何做/站长工具之家
  • 关于公司网站建设的请示/抚顺优化seo
  • 九江网站开发公司/seo整站优化新站快速排名
  • 微信网站建设报价表/黄页网络的推广
  • dw设计做网站完整案例/可以直接进入的舆情网站