当前位置: 首页 > news >正文

特征金字塔网络(FPN)详解

基础介绍

特征金字塔网络是是一种多尺度特征表示方法,用于解决目标检测、图像分割等任务中对不同尺寸的目标的检测问题。特征金字塔网络会有多个尺度的特征图输出,每个大尺度的特征图都包含小尺度特征图的信息,进而将小尺度特征图中的信息融合到大尺度的特征图中,这样在大尺度中也会包含只有小尺度特征图的语义信息。

主要特点:

  • 多尺度特征图表示:这句话的意思是说特征金字塔网络会有多个输出,每个输出的尺度各不相同,但是通道数相同。这些不同的尺度就表示不同的分辨率,在特征金字塔网络输出的基础上,后面接入各种检测头就可以实现不同的功能,这些检测头需要对这些不同尺度的特征图进行预测。
  • 自顶向下的信息传递:这里的自顶向下是只从小尺度特征图向大尺度特征图方向传递信息,因为小尺度图特征图感受野大,包含极强的语义信息;而大尺度特征图则包含更多的空间信息、细节信息等。与之相对的是自底向上路径,它的意思是从大尺度特征图向小尺度方向的变化,即传统意义上的主干网络,实现特征图提取。
  • 横向连接的特征融合:每一个特征金字塔网络的输出特征图都包含小尺度特征的信息,这就是通过横向连接方式实现的,需要对低尺度的特征图进行上采样+修改通道数。通过横向连接特征融合实现了细节和语义的融合
  • 语义信息的层级显示

解决的主要问题

1.解决尺度不变性的问题

  • 解决目标检测中不同尺度目标的检测问题
  • 提供多尺度特征表示

2.特征表示问题

  • 低层特征:丰富的细节信息,但语义信息少
  • 高层特征:强语义信息,但位置信息少
  • FPN:两者优势结合

3.小目标检测问题

  • 利用低层特征的高分辨率
  • 同时包含高层的语义信息

多尺度融合的思想

多尺度融合是特征金字塔网络的核心。举个例子方便理解:

P2 P3 P4 P5分别表示特征金字塔网络的输出,他们的尺度不一样,通道数一样,后端检测头基于这些特征图进行检测,它们的侧重点也会有所不同。

P2:高分辨率,适合小目标
P3-P4:中等分辨率,适合中等目标
P5:低分辨率,适合大目标

特征金字塔的核心思想

 就像一个“完美的望远镜”。

1. 能同时看清大物体和小物体
2. 既能看到整体,又能看到细节
3. 多个尺度信息融合在一起

也就是特征金字塔网络既需要小尺度的高层特征,也需要低尺度的细节信息。 通过融合的方式来实现这个需求。

 特征金字塔网络实现要点

上采样操作

# PyTorch示例
import torch.nn.functional as F

def upsample_add(x, y):
    return F.interpolate(x, size=y.shape[2:], mode='nearest')

通过上面的例子看到是通过interpolate函数来是实现的上采样,而不是通过转置卷积。上面代码的含义是将x张量上采样为与y张量相同的尺度,即分辨率一样(尺度大小一样)。

特征融合

# 1x1卷积调整通道
lateral_conv = nn.Conv2d(in_channels, out_channels, 1)

# 特征融合
fused_feature = lateral_conv(c_feature) + upsample_feature

这里先通过1x1的卷积核改变张量的通道,然后与上采样的张量求和实现特征融合。

特征金字塔网络完整实例

基础模块实现

class FPNBlock(nn.Module):
    def __init__(self, C_in, C_out):
        super().__init__()
        # 横向连接的1x1卷积
        self.lateral = nn.Conv2d(C_in, C_out, 1)
        # 特征融合后的3x3卷积
        self.fuse = nn.Conv2d(C_out, C_out, 3, padding=1)
        
    def forward(self, x, upsampled=None):
        # 横向连接
        lateral = self.lateral(x)
        
        # 如果有上采样的特征,进行融合
        if upsampled is not None:
            lateral = lateral + upsampled
            
        # 3x3卷积处理融合后的特征
        out = self.fuse(lateral)
        return out, lateral

该基础模块的作用是实现两个特征图的融合。这个fpn基础网络模块前向传播函数核心逻辑如下:

  1. 首先利用lateral调整了特征图的通道数为统一通道数,确保所有的特征图虽然尺度大小不一样,但通道数是一样
  2. 如果存在上采样后的特征图,则将第一步调整后的特征图与上采样的特征图相加
  3. fuse实现对融合的特征图做一个平和处理,输出特征图尺度大小不变

完整fpn网络

class CompleteFPN(nn.Module):
    def __init__(self):
        super().__init__()
        # 假设使用ResNet主干网络
        self.backbone = resnet50(pretrained=True)
        
        # FPN层
        self.fpn_c2 = FPNBlock(256, 256)
        self.fpn_c3 = FPNBlock(512, 256)
        self.fpn_c4 = FPNBlock(1024, 256)
        self.fpn_c5 = FPNBlock(2048, 256)
        
    def forward(self, x):
        # 主干网络特征提取
        c2 = self.backbone.layer1(x)  # 1/4
        c3 = self.backbone.layer2(c2) # 1/8
        c4 = self.backbone.layer3(c3) # 1/16
        c5 = self.backbone.layer4(c4) # 1/32
        
        # 自顶向下路径
        p5, lateral5 = self.fpn_c5(c5)
        
        # 上采样p5并与c4融合
        up_5 = F.interpolate(lateral5, c4.shape[-2:], mode='nearest')
        p4, lateral4 = self.fpn_c4(c4, up_5)
        
        # 上采样p4并与c3融合
        up_4 = F.interpolate(lateral4, c3.shape[-2:], mode='nearest')
        p3, lateral3 = self.fpn_c3(c3, up_4)
        
        # 上采样p3并与c2融合
        up_3 = F.interpolate(lateral3, c2.shape[-2:], mode='nearest')
        p2, _ = self.fpn_c2(c2, up_3)
        
        return [p2, p3, p4, p5]

在上面的简单fpn的前向传播函数的实现逻辑如下所示:

  1. 利用主干网络实现下采样,特征提取,该步骤属于自底向上的路径操作
  2. 从最小尺度的特征图开始,首先对这个特征图调用fpn网络基础模块的前向传播函数,由于它是最后一层特征图,没有下一层,所以这里的基础模块仅仅将它的通道数从2048调整为了256,然后进行了平滑操作,返回两个值,一个是平滑后的特征图p5,一个是未经平滑的特征图(lateral5)
  3. 将lateral5进行上采样,将它的尺度变成与c4的尺度相同得到up_5,然后对c4调用fpn网络基础模块,由于c4不是最后一层特征图,那么它一定存在上采样后的特征图,这里是up_5,调用fpn_c4后得到c4和up_5融合后的新的特征图P4
  4. 其它特征图重复步骤3的操作,依次进行特征图的层层融合,最后得到p2,p3,p4,p5

使用实例

# 创建模型
model = CompleteFPN()

# 准备输入
image = torch.randn(1, 3, 800, 800)

# 前向传播
feature_maps = model(image)

# 查看各特征图尺寸
for i, fm in enumerate(feature_maps):
    print(f"P{i+2} shape:", fm.shape)

应用建议

# 常用的特征图尺寸比例
scales = {
    'P2': 1/4,   # 适合小物体
    'P3': 1/8,   # 适合中小物体
    'P4': 1/16,  # 适合中等物体
    'P5': 1/32   # 适合大物体
}

上面的1/4,1/8等的含义表示的是尺度的大小,即相对于原图的大小。

http://www.dtcms.com/a/109020.html

相关文章:

  • 【易订货-注册/登录安全分析报告】
  • Oracle触发器使用(二):伪记录和系统触发器
  • 构建个人专属知识库文件的RAG的大模型应用
  • BUUCTF-web刷题篇(9)
  • idea插件(自用)
  • video标签播放mp4格式视频只有声音没有图像的问题
  • NVIDIA显卡
  • 2.3 路径问题专题:剑指 Offer 47. 礼物的最大价值
  • Apollo配置中心登陆页面添加验证码
  • OpenCV销毁窗口
  • 浅谈软件成分分析 (SCA) 在企业开发安全建设中的落地思路
  • 数据库--SQL
  • Pytorch深度学习框架60天进阶学习计划 - 第34天:自动化模型调优
  • 维拉工时自定义字段:赋能项目数据的深度洞察 | 上新预告
  • React-router v7 第一章(安装)
  • JDBC常用的接口
  • coding ability 展开第八幕(位运算——基础篇)超详细!!!!
  • Spring Boot 集成 Redis 对哈希数据的详细操作示例,涵盖不同结构类型(基础类型、对象、嵌套结构)的完整代码及注释
  • PyQt6实例_A股日数据维护工具_使用
  • OpenCV 引擎:驱动实时应用开发的科技狂飙
  • 操作系统(一):概念及主流系统全分析
  • 大模型学习三:DeepSeek R1蒸馏模型组ollama调用流程
  • Vue2 生命周期
  • Adam vs SGD vs RMSProp:PyTorch优化器选择
  • 美关税加征下,Odoo免费开源ERP如何助企业破局?
  • 【无标题 langsmith
  • DNS域名解析过程 + 安全 / 性能优化方向
  • 在线下载国内外各种常见视频网站视频的网页端工具
  • frp 让服务器远程调用本地的服务(比如你的java 8080项目)
  • AIGC7——AIGC驱动的视听内容定制化革命:从Sora到商业化落地