读论文--Vision Transformer with Deformable Attention(DAT)完整教程
Vision Transformer with Deformable Attention(DAT)完整教程
作者:清华大学、AWS、北京AI研究院
发表:2022年5月(CVPR 2022投稿)
论文链接:https://arxiv.org/abs/2201.00520
代码仓库:https://github.com/LeapLabTHU/DAT
难度评级:⭐⭐⭐⭐(高级,需要Transformer基础)
一、论文核心摘要
1.1 研究背景
Vision Transformer的成功表明,自注意力机制在计算机视觉中具有强大的表达能力。然而,现有的vision transformer架构在注意力模式的设计上存在根本的矛盾:
┌─────────────────────────────────────────────┐
│ Vision Transformer 注意力机制的两难困境 │
├─────────────────────────────────────────────┤
│ │
│ 方案A:ViT(完整注意力) │
│ ├─ 优点:能建立所有位置的交互 │
│ ├─ 缺点:O(N²)复杂度,内存占用巨大 │
│ └─ 问题:被不相关区域干扰,容易过拟合 │
│ │
│ 方案B:Swin/PVT(稀疏注意力) │
│ ├─ 优点:计算高效,O(N·w²)复杂度 │
│ ├─ 缺点:固定的注意力模式(数据无关) │
│ └─ 问题:可能丢弃重要信息,无法自适应 │
│ │
└─────────────────────────────────────────────┘
1.2 论文的核心问题
问题陈述:现有的注意力优化要么计算太贵(完整注意力),要么太生硬(固定模式)。我们能否设计一种计算高效且数据自适应的注意力机制?
理想目标:
- ✅ 计算高效(不能是O(N²))
- ✅ 数据自适应(根据输入调整注意力)
- ✅ 保留重要信息(不丢弃关键特征)
- ✅ 建立长距离依赖(处理多尺度物体)
1.3 论文的创新方案
DAT(Deformable Attention Transformer)提出了可变形注意力机制:
关键洞察:观察(GCNet等研究发现)↓不同Query的注意力模式通常相似↓不需要每个Query都学习独立偏移↓
核心想法:用共享的、学习的偏移量引导注意力↓参考点 + 学习的偏移 → 变形点 → 采样特征↓
实现特点:• 线性空间复杂度 O(H·W·r²)• 完全可微(双线性插值)• 数据自适应(Offset Network)• 实现简洁(可在PyTorch中几行代码实现)
二、技术原理深度解析
2.1 标准自注意力回顾
首先回顾标准的多头自注意力(MHSA):
# 标准Self-Attention的计算流程# 输入:特征图 x ∈ R^(N×C),其中N=H×W是patch数量# 步骤1:投影得到Q、K、V
q = x @ W_q # Q ∈ R^(N×C)
k = x @ W_k # K ∈ R^(N×C)
v = x @ W_v # V ∈ R^(N×C)# 步骤2:计算注意力权重
# 这里是计算量的主要来源!
scores = (q @ k.T) / sqrt(d) # (N×N) 的计算,复杂度O(N²)
attention = softmax(scores) # N×N 的矩阵# 步骤3:聚合特征
output = attention @ v # (N×N) @ (N×C) = (N×C)
问题分析:
对于56×56的特征图(来自224×224的图像):
- N = 56 × 56 = 3,136个patch
- 注意力矩阵大小:3,136 × 3,136 ≈ 980万个元素
- 内存占用:980万 × 4字节(float32)≈ 40MB(这还只是一个头)
当使用16个注意力头时,内存占用会成倍增加,这就是为什么ViT在大图像上不可行的原因。
2.2 现有改进方案的局限
方案A:Swin Transformer的窗口注意力
# Swin的做法:限制在窗口内计算注意力def window_attention(features, window_size=7):"""将特征图分成不重叠的窗口,在每个窗口内做自注意力"""H, W, C = features.shape# 将特征分成7×7的窗口windows = features.reshape(H // window_size, window_size,W // window_size, window_size, C)# 在每个7×7窗口内做自注意力# 每个窗口的计算量:49×49 = 2,401(远小于3,136×3,136)# 总的计算复杂度从O(N²)变成O(N·w²)# 其中w是窗口大小
优点:计算量大幅降低(从O(N²)到O(N·w²))
缺点:
问题1:硬边界问题
┌─────────┬─────────┐
│ 窗口1 │ 窗口2 │
│ 狗 │ 耳朵 │ ← 狗的身体被窗口边界切开
└─────────┴─────────┘
结果:身体和耳朵无法建立联系问题2:数据无关
- 所有图像使用同样的窗口划分
- 无论内容是什么,都是这样分
- 不能根据物体位置自适应调整问题3:长距离依赖较弱
- 两个远距离的patch要交互
- 需要通过多个层传递信息
- 导致感受野增长缓慢
方案B:PVT的键值下采样
def pvt_attention(features):"""PVT的改进:下采样键值,保持查询分辨率"""H, W, C = features.shape# Query保持原始分辨率q = features @ W_q # (H×W) × C# Key和Value下采样到1/4downsampled = adaptive_pool(features, scale=2)k = downsampled @ W_k # (H/2×W/2) × Cv = downsampled @ W_v # (H/2×W/2) × C# 注意力计算# 复杂度从O(N²)变成O(N·N/4) = O(N²/4)
优点:计算量降低(通过减少Key/Value)
缺点:
问题1:信息丢失
- 高分辨率的细节被下采样后丢失
- 特别不利于检测小物体问题2:无法捕捉局部细节
- 小物体的fine-grained特征丢失
- 导致性能下降问题3:依然数据无关
- 固定的下采样比例
- 不能根据内容调整
2.3 可变形注意力的解决方案
核心想法的演进
灵感来源:Deformable Convolution Networks (DCN)↓在CNN中,DCN学习灵活的感受野每个像素学习自己的偏移量↓为什么不在Transformer中用?↓朴素移植会导致O(N²)的复杂度↓
关键观察:GCNet和DeepViT的发现↓不同Query的注意力权重分布相似↓不需要每个Query独立学习偏移↓
DAT的方案:共享偏移 + 偏移分组↓保持线性复杂度 + 获得数据自适应性
核心机制:参考点 + 偏移
参考点生成(Reference Points Generation)
───────────────────────────────────────────输入特征图:H × W × C第1步:生成统一的参考点网格• 原始特征大小:H × W• 参考点网格大小:H/r × W/r(r通常=8)• 坐标:线性间隔的2D坐标• 范围:[-1, +1](归一化到特征图坐标系)例子(H=56, W=56, r=8):参考点网格:7 × 7 = 49个点均匀分布在整个特征图上第2步:学习偏移量(Learned Offsets)• 输入:Query特征 q = x @ W_q• 通过Offset Network学习偏移• 输出:每个参考点的偏移 ΔpΔp = θ_offset(q)Δp ∈ R^(7×7×2) # 每个参考点有(Δx, Δy)两个值为了稳定训练,缩放偏移:Δp_scaled = s · tanh(Δp) # s是缩放因子,如s=6第3步:生成变形点(Deformed Points)新的采样位置 = 参考点 + 偏移p_deformed = p_ref + Δp_scaled这些新的位置可能不在整数坐标上例如:(3.2, 4.7)这样的浮点坐标第4步:采样特征(Feature Sampling)使用双线性插值在新位置采样:x̃ = bilinear_interpolation(x, p_deformed)这保证了可微性(梯度能反向传播)
双线性插值的数学原理
def bilinear_interpolation(feature_map, coord):"""在浮点坐标处进行双线性插值feature_map: (H, W, C)coord: (2,) 浮点坐标 [x, y]"""px, py = coord# 找到最近的4个整数坐标x0, y0 = int(floor(px)), int(floor(py))x1, y1 = x0 + 1, y0 + 1# 计算插值权重(距离权重)# g(a, b) = max(0, 1 - |a - b|)wx0 = 1 - (px - x0) # 到左边界的权重wx1 = px - x0 # 到右边界的权重wy0 = 1 - (py - y0) # 到上边界的权重wy1 = py - y0 # 到下边界的权重# 双线性插值value = (wx0 * wy0 * feature_map[x0, y0] +wx1 * wy0 * feature_map[x1, y0] +wx0 * wy1 * feature_map[x0, y1] +wx1 * wy1 * feature_map[x1, y1])return value
关键特性:
- 处处可微:梯度能平滑传递
- 效率高:只涉及4个最近邻的线性组合
- 实现简洁:PyTorch中torch.grid_sample即可
2.4 Offset Network设计
偏移量不能随意学习,需要有结构化的设计:
class OffsetNetwork(nn.Module):"""生成参考点的偏移输入:Query特征(已投影)输出:每个参考点的偏移量"""def __init__(self, in_channels, out_channels=2):super().__init__()# 第一部分:5×5深度卷积# 为什么深度卷积?# - 感知局部特征(每个参考点覆盖的区域)# - 参数量更少# - 有利于学习局部有意义的偏移self.dw_conv = nn.Conv2d(in_channels, in_channels,kernel_size=5, padding=2,groups=in_channels, # 深度卷积bias=True)# 激活函数self.gelu = nn.GELU()# 第二部分:1×1卷积,无偏置# 为什么无偏置?# - 偏置会强制所有位置都有偏移# - 我们希望某些位置偏移为0# - 所以去掉偏置让模型自由学习self.pw_conv = nn.Conv2d(in_channels, out_channels,kernel_size=1, bias=False)def forward(self, query_features):"""query_features: (B, H, W, C)output: (B, H_G, W_G, 2) # 每个参考点的(Δx, Δy)"""# 应用卷积序列x = self.dw_conv(query_features)x = self.gelu(x)offsets = self.pw_conv(x)return offsets
设计哲学:
为什么这样设计Offset Network?1. 深度卷积的局部感知├─ 参考点p覆盖s×s的区域├─ Offset Network应该感知这个局部├─ 5×5卷积足以覆盖这个上下文└─ 学到的偏移才有意义2. 无偏置的自由性├─ 有偏置 = 强制所有位置都偏移├─ 这不符合直觉├─ 某些位置应该偏移为0(已经在重要区域)└─ 无偏置让模型自己决定3. 两层卷积的表达力├─ 一层太简单├─ 三层太复杂├─ 两层是最优平衡└─ 在实践中证明有效
2.5 偏移分组增加多样性
class DeformableAttentionModule(nn.Module):"""可变形多头自注意力(DMHA)"""def __init__(self, dim, heads=8, offset_groups=3):super().__init__()self.heads = headsself.offset_groups = offset_groups# 关键设计:多个Offset子网络self.offset_networks = nn.ModuleList([OffsetNetwork(dim // offset_groups, out_channels=2)for _ in range(offset_groups)])# 为什么分组?# 原理类似多头注意力:# - 不同的头有不同的表示子空间# - 每个子网络生成的偏移也应该多样# - 增加表现力# - 保持参数量不爆炸
关键参数设置:
假设总通道数 D=384,总头数 M=12,偏移分组数 G=3配置1:每个Offset组对应多个注意力头
├─ 每组通道数:384/3 = 128
├─ 每组对应头数:12/3 = 4个头
├─ 这4个头共享同一组偏移
└─ 好处:减少参数,保持多样性实际效果:
✓ G=3时:3组不同的偏移模式
✓ 但参数量只增加微小
✓ 性能提升明显
2.6 可变形相对位置偏置
标准的Swin使用固定的相对位置偏置表,但DAT中的Key位置是连续浮点坐标,需要适配:
def deformable_relative_position_bias(query_pos, # (N_q, 2)key_pos, # (N_s, 2) N_s是参考点数量bias_table # (2H-1, 2W-1) 预定义的位置表
):"""计算可变形的相对位置偏置关键问题:Key位置是浮点数,无法直接查表解决方案:在位置表上进行双线性插值"""# 计算相对距离(可能是浮点数)relative_dist = key_pos - query_pos # (N_q, N_s, 2)# 归一化到[-1, +1]范围(对应位置表的坐标系)relative_dist_normalized = normalize(relative_dist)# 在位置表上进行双线性插值position_bias = bilinear_interpolate(bias_table, relative_dist_normalized)return position_bias # (N_q, N_s)
与标准Swin的对比:
┌─────────────────────────┬──────────────────┐
│ 特性 │ Swin │ DAT │
├─────────────────────────┼───────┼──────────┤
│ 位置表大小 │ 固定 │ 固定 │
│ 支持的坐标范围 │ 有限 │ 无限 │
│ 相对距离类型 │ 整数 │ 浮点 │
│ 查表方式 │ 直接 │ 插值 │
│ 梯度流 │ 离散 │ 连续 │
│ 适配可变形点 │ 否 │ 是 │
└─────────────────────────┴───────┴──────────┘
三、网络架构详解
3.1 整体架构
输入图像 (H × W × 3)↓
[Patch Embedding: 4×4卷积 stride=4]↓
Stage 1: 特征图大小 56×56
├─ Local Attention × N1
├─ Shift-Window Attention × N1 ← 为什么这里不用可变形?
├─ 原因:高分辨率,计算量大
└─ 采样点太多,开销太高↓ [下采样: 2×2卷积 stride=2]Stage 2: 特征图大小 28×28
├─ Local Attention × N2
├─ Shift-Window Attention × N2 ← 仍用Swin
└─ 原因同上↓ [下采样: 2×2卷积 stride=2]Stage 3: 特征图大小 14×14 ★ 开始使用可变形注意力
├─ Local Attention × N3
├─ Deformable Attention × N3 ← 开始这里
└─ 理由:分辨率降低,开销可承受需要全局建模,可变形有优势↓ [下采样: 2×2卷积 stride=2]Stage 4: 特征图大小 7×7 ★ 充分发挥可变形优势
├─ Local Attention × N4
├─ Deformable Attention × N4 ← 最后阶段
└─ 理由:最低分辨率,计算最高效最需要全局交互,可变形最有效↓
输出头
├─ 分类任务:[Global Average Pooling] → [Linear Classifier]
├─ 检测任务:+ FPN + Detection Head
└─ 分割任务:+ Decoder + Segmentation Head
3.2 Local Attention模块
[Local Attention]↓
在7×7的窗口内计算标准自注意力目的:聚集窗口内的局部信息
特点:
- 高效(O(w²)复杂度,w=7)
- 鲁棒(已被Swin验证)
- 必要(后续模块的基础)
3.3 Deformable Attention模块
[Deformable Attention] 详细流程输入:Query特征 q ∈ R^(H×W×C)第1步:生成参考点参考点网格:H/8 × W/8 个点范围:[-1, +1]归一化坐标第2步:学习偏移偏移 = OffsetNetwork(q)缩放:Δp = s·tanh(Δp)第3步:生成变形点p_def = p_ref + Δp第4步:采样特征x̃ = bilinear_interpolate(x, p_def)第5步:投影K、Vk̃ = x̃ @ W_kṽ = x̃ @ W_v第6步:计算注意力A = softmax(q @ k̃^T / √d + bias)output = A @ ṽ输出:变形后的特征
3.4 为什么交替使用Local和Deformable?
Block设计模式:[Local] → [Deformable] → [Local] → ...第1个Block:Local Attention↓聚集相邻patch的局部信息形成更新、更有表现力的特征第2个Block:Deformable Attention↓在这个更新的特征基础上灵活建立全局关系引导特征流向重要区域第3个Block:Local Attention↓重新处理已经具有全局上下文的特征进一步精化局部表示...循环迭代效果:
✓ 信息逐层完善
✓ 局部→全局→局部...多尺度循环
✓ 类似人类视觉系统(看整体→看细节)
3.5 DAT的三个模型变体
┌─────────┬──────────┬──────────┬──────────┐
│ 参数 │ DAT-T │ DAT-S │ DAT-B │
├─────────┼──────────┼──────────┼──────────┤
│ 参数量 │ 29M │ 50M │ 88M │
│ FLOPs │ 4.6G │ 9.0G │ 15.8G │
│ Channels│ 96-768 │ 96-768 │ 128-1024 │
│ 应用 │ 轻量级 │ 通用 │ 高性能 │
└─────────┴──────────┴──────────┴──────────┘设计选择的灵活性:
✓ 可以任意调整Stage的深度
✓ 可以任意调整通道数
✓ 可以任意调整注意力头数
✓ 可以自定义偏移分组数
四、实验结果完整分析
4.1 图像分类(ImageNet-1K)
分类精度对比
基准:Swin Transformer模型规模 参数 FLOPs Swin精度 DAT精度 提升
─────────────────────────────────────────────────────
Tiny 29M 4.5G 81.3% 82.0% +0.7%
Small 50M 8.8G 83.0% 83.7% +0.7%
Base 88M 15.5G 83.5% 84.0% +0.5%高分辨率微调(384×384)
Base 88M 47.2G 84.5% 84.8% +0.3%
关键观察:
1. 稳定的性能提升├─ 小模型:+0.7%(很显著)├─ 中模型:+0.7%(很显著)└─ 大模型:+0.5%(仍有提升)2. 在参数相同/接近的情况下✓ DAT始终优于Swin✓ 计算量只增加~2%✓ 但精度提升明显3. 高分辨率微调仍保持优势✓ 即使在384×384这样的高分辨率✓ DAT依然领先Swin✓ 说明可变形注意力的普适性
对比其他方法
所有竞争对手(Small模型级别)方法 参数 FLOPs 准确率
───────────────────────────────────
PVT-M 46M 6.9G 81.2%
DPT-M 46M 6.9G 81.9%
Swin-S 50M 8.8G 83.0%
DAT-S 50M 9.0G 83.7% ✓ 最佳
─────────────────────────────────────Tiny模型级别DPT-S 26M 4.0G 81.0%
Swin-T 29M 4.5G 81.3%
DAT-T 29M 4.6G 82.0% ✓ 最佳
4.2 目标检测(COCO)
RetinaNet一阶段检测器
配置:使用RetinaNet检测框架
输入分辨率:1280×800
训练计划:1×和3×两种1× 训练计划结果(12个Epoch):模型 参数 FLOPs AP AP50 AP75 改进
───────────────────────────────────────────────────
Swin-T 38M 248G 41.7 63.1 44.3 -
DAT-T 38M 253G 42.8 64.4 45.2 +1.1%Swin-S 60M 339G 44.5 66.1 47.4 -
DAT-S 60M 359G 45.7 67.7 48.5 +1.2%3× 训练计划结果(36个Epoch,训练更充分):Swin-T 38M 248G 44.8 66.1 48.0 -
DAT-T 38M 253G 45.6 67.2 48.5 +0.8%Swin-S 60M 339G 47.3 68.6 50.8 -
DAT-S 60M 359G 47.9 69.6 51.2 +0.6%
Mask R-CNN二阶段检测器
这通常给出更高精度的结这通常给出更高精度的结果1× 训练计划:模型 参数 FLOPs AP_box AP_mask 改进
──────────────────────────────────────────
Swin-T 48M 267G 43.7 39.8 -
DAT-T 48M 272G 44.4 40.4 +0.7%Swin-S 69M 359G 45.7 41.1 -
DAT-S 69M 378G 47.1 42.5 +1.4%3× 训练计划(更多数据增强、更长训练):Swin-T 48M 267G 46.0 41.6 -
DAT-T 48M 272G 47.1 42.4 +1.1%Swin-S 69M 359G 48.5 43.3 -
DAT-S 69M 378G 49.0 44.0 +0.5%Cascade Mask R-CNN(最高精度框架)3×训练:Swin-T 86M 745G 50.4 43.7 -
DAT-T 86M 750G 51.3 44.5 +0.9%Swin-S 107M 838G 51.9 45.0 -
DAT-S 107M 857G 52.7 45.5 +0.8%Swin-B 145M 982G 51.9 45.0 -
DAT-B 145M 1003G 53.0 45.8 +1.1%
不同物体大小的性能分析
这是DAT最显著的优势所在!物体大小分类:
- AP_s:小物体(面积 < 32²)
- AP_m:中等物体(32² < 面积 < 96²)
- AP_l:大物体(面积 > 96²)DAT相比Swin的改进(在Cascade Mask R-CNN上):物体大小 Swin-T DAT-T 改进 改进倍数
────────────────────────────────────────────
小物体 30.4 34.1 +3.7% ★★★★
中等物体 51.5 54.6 +3.1% ★★★
大物体 63.1 66.9 +3.8% ★★★★发现模式:
✓ 小和大物体的改进都非常显著(+3-4%)
✓ 中等物体改进相对较小(+3%)为什么会这样?小物体的好处分析:
├─ 需要精细的细节特征
├─ Swin的窗口可能不够灵活
├─ 可变形注意力能自适应采样细节区域
└─ 显著提升小物体检测精度大物体的好处分析:
├─ 需要建立物体不同部分的关系
├─ Swin的硬边界可能切断物体
├─ 可变形注意力能灵活跨越距离
├─ 建立整体物体的连贯表示
└─ 显著提升大物体检测精度中等物体的情况:
├─ 既有局部又有全局信息
├─ 两种方法都能相对好地处理
└─ 改进空间相对有限
4.3 语义分割(ADE20K)
ADE20K是最复杂的任务,需要像素级精度使用SemanticFPN框架(轻量级):模型 参数 FLOPs mIoU 改进
─────────────────────────────────
PVT-S 28M 225G 41.95 -
DAT-T 32M 198G 42.56 +0.6% ★ 用更少FLOPs!PVT-M 48M 315G 42.91 -
DAT-S 53M 320G 46.08 +3.2% ★ 大幅提升!PVT-L 65M 420G 43.49 -
DAT-B 92M 481G 47.02 +3.5%Swin-T 60M 945G 44.51 -
DAT-T 60M 957G 45.54 +1.0%Swin-S 81M 1038G 47.64 -
DAT-S 81M 1079G 48.31 +0.7%Swin-B 121M 1188G 48.13 -
DAT-B 121M 1212G 49.38 +1.2%使用UPerNet框架(高端):Swin-T 60M 945G 44.51 -
DAT-T 60M 957G 45.54 +1.0%Swin-S 81M 1038G 47.64 -
DAT-S 81M 1079G 48.31 +0.7%Swin-B 121M 1188G 48.13 -
DAT-B 121M 1212G 49.38 +1.2%多尺度测试(MS IOU):Swin-B 121M 1188G 49.72 -
DAT-B 121M 1212G 50.55 +0.8%
分割任务的特殊性:
为什么分割任务也有显著提升?分割 vs 分类 vs 检测:1. 分类(全图决策)- 只需要整体判断- 单尺度特征基本够- 改进空间有限2. 检测(区域定位)- 需要定位多个物体- 物体有不同尺寸- 需要多尺度交互3. 分割(像素级精度)★最需要多尺度- 需要精确的边界- 需要语义一致性- 需要局部细节+全局上下文- 需要像素级的准确度DAT在分割上的优势:
✓ 可变形注意力能灵活聚焦边界区域
✓ 能在局部和全局间平衡
✓ 特别适合需要精细特征的任务
五、消融研究与设计验证
5.1 可变形偏移与位置偏置的协同效应
Table 6:几何信息的利用配置 FLOPs 参数 准确率 vs基线
─────────────────────────────────────────────────────
无任何改进(基础Swin) 4.51G 28.29M 81.3% -0.7%仅有相对位置偏置 4.57G 28.32M 81.7% -0.3%
(没有偏移)仅有偏移采样 4.58G 28.29M 81.7% -0.3%
(没有位置偏置)偏移 + 固定位置编码 4.58G 29.73M 81.8% -0.2%偏移 + DWConv位置编码 4.59G 28.31M 81.8% -0.2%偏移 + 相对位置偏置 4.59G 28.32M 82.0% +0.0% ← 完整DAT
(DAT完整方案)关键发现:
═══════════════════════════════════1. 单独使用的效果├─ 仅偏移:+0.3%改进(不错)├─ 仅相对位置偏置:+0.3%改进(同样好)└─ 两者都差不多2. 组合使用的效果├─ 偏移 + 相对位置偏置:+0.7%改进├─ vs仅用偏移的+0.3%├─ 总改进:0.3% + 0.3% = 0.6% < 0.7%└─ 存在+0.1%的协同效应!3. 启示✓ 两个组件都有用✓ 组合时有超加性效应✓ 说明设计的巧妙性✓ 不是简单的加法,而是乘法
5.2 不同阶段应用可变形注意力
Table 7:在哪些阶段使用可变形注意力最优?应用阶段 FLOPs 参数 准确率 对比
─────────────────────────────────────────────
Stage 1,2,3,4 4.64G 28.39M 81.7% -0.3%
(所有阶段都用)Stage 2,3,4 4.60G 28.34M 81.9% -0.1%
(第1阶段除外)Stage 3,4 4.59G 28.32M 82.0% +0.0% ← 最优
(只有后两阶段)Stage 4 4.51G 28.29M 81.4% -0.6%
(仅最后阶段)Swin-T基线 4.51G 28.29M 81.3% 基准分析结果:
═════════════════════════════════1. 全局应用的问题(所有4个阶段)├─ 精度反而下降到81.7%├─ 为什么?高分辨率阶段开销太大├─ Stage 1(56×56), Stage 2(28×28)├─ 参考点太多 → 双线性插值计算量大├─ 收益小但成本高└─ 划不来2. 早期阶段的低效性├─ Stage 1, 2的分辨率很高├─ 参考点数量:56×56/8 = 49个├─ 双线性插值计算量 ∝ 参考点数├─ 成本高但改进小├─ 因为这些阶段主要学习局部特征└─ Swin的窗口注意力已经足够3. 后期阶段的高效性├─ Stage 3(14×14), Stage 4(7×7)├─ 参考点数量:14×14/8 = 3个(太少)│ 实际:可以调整r值,或者直接用全部点├─ 但计算量相对很低├─ 这些阶段需要全局建模├─ 可变形注意力优势显著└─ 效率和效果的完美平衡4. 最优策略✓ Stage 1-2:Shift-Window(高效且已验证)✓ Stage 3-4:Deformable(低分辨率,效果好)✓ 这是设计的智慧:因地制宜
5.3 偏移范围因子(s)的鲁棒性
Figure 4:不同s值下的性能偏移范围(s) 准确率 性能评价
────────────────────────────
0 80.6% ✗ 最差(无偏移,退化为Swin)
2 81.9% ✓ 良好
4 81.95% ✓ 良好
6 82.0% ★ 最佳
8 81.98% ✓ 良好
10 81.9% ✓ 良好
12 81.8% ✓ 中等
14 81.7% ✓ 中等
16 81.6% ✗ 变差(超出最大合理偏移)关键特征:
═══════════════════1. 宽泛的有效范围├─ 从s=2到s=14├─ 性能相对稳定├─ 差异仅0.3%└─ 这是很好的鲁棒性证明2. 最优值附近├─ s=6是性能顶峰├─ 但s=2-8的性能差异很小├─ 实践中选s=2就足够好└─ 不需要精细调参3. 边界效应├─ s太小(=0):无法学习有意义的偏移├─ s太大(=16):超出特征图范围,无效└─ 合理的范围内性能都不错4. 设计启示✓ 好的设计应该对超参数鲁棒✓ 不应该过度依赖精细调参✓ DAT通过这个测试✓ 实用性强
5.4 与Deformable DETR的对比
Appendix A详细对比Deformable DETR的问题:
─────────────────────1. 设计理念的差异Deformable DETR:├─ 用线性投影预测注意力权重├─ A = σ(W_att @ x) (W_att是参数矩阵)├─ 权重不是通过query-key交互计算└─ 更像卷积而非注意力DAT:├─ 使用真正的点积注意力├─ A = softmax(q @ k^T / √d)├─ 保持注意力机制的本质└─ Query和Key完整交互2. 实验对比(Table 8)配置 Keys FLOPs 内存 精度
──────────────────────────────────────────────
D-DETR (16个Key) 16 4.44G 13.9GB 80.6%
D-DETR (49个Key) 49 4.83G 18.8GB 80.7%
D-DETR (196个Key) 196 6.16G 37.9GB 79.2% ✗ 性能反而下降!DAT (49个Reference) 49 4.38G 12.5GB 81.8% ✓ 最优
DAT (196个Reference) 196 4.59G 14.4GB 82.0% ✓ 性能继续提升3. 结论在相同Key/Reference数量下:
├─ DAT的内存占用更低
├─ DAT的性能更好
├─ DAT的设计更优雅
└─ DAT更适合作为BackboneDAT相比D-DETR的优势:
✓ 真正的注意力机制(不是近似)
✓ 所有Key都参与计算(不丢弃信息)
✓ 高效的实现(共享偏移)
✓ 可扩展的设计(任意Key数量)
5.5 可视化验证
学习到的变形点
Figure 5-7的可视化分析观察1:前景对象聚焦
─────────────────
图像:野生动物(如长颈鹿)变形点分布:
◆ ◆ ● ◆ = 高权重(大圆圈)
◆ ◆ ● ● = 低权重(小圆圈)
● ●特点:
✓ 变形点集中在长颈鹿身体
✓ 忽略背景区域
✓ 自动学会了前景-背景分离对比Swin:
├─ Swin用固定7×7窗口
├─ 无论内容如何都是一样
└─ DAT能自适应调整观察2:多物体场景
─────────────────
图像:网球运动员挥拍变形点分布:
● △ ● ● = 人体相关点△ = 球拍相关点特点:
✓ 不同物体有不同变形模式
✓ 能同时处理多物体
✓ 每个物体的注意力独立优化
✓ 体现了数据自适应性观察3:精细几何形状
──────────────────
图像:多个甜甜圈变形点分布:
● ● ●
● ● ●特点:
✓ 甜甜圈边界清晰
✓ 每个物体都有专注点
✓ 保留了精细的几何特征
✓ 证明能处理复杂场景Swin vs DAT对比:
┌─────────────────────┬──────────────────┐
│ 特性 │ Swin │ DAT │
├─────────────────────┼────────┼──────────┤
│ 注意力模式 │ 固定 │ 自适应 │
│ 前景-背景区分 │ 弱 │ 强 │
│ 多物体处理 │ 平均 │ 精准 │
│ 边界处理 │ 硬切 │ 平滑 │
│ 物体完整性 │ 可能被切│ 保持 │
└─────────────────────┴────────┴──────────┘
注意力权重热力图
Figure 7详细分析Swin Transformer的注意力:
┌─────┬─────────────┐
│ ███ │ █████ █ │
│ ███ │ ██████ ██ │ ← 只关注窗口内
│ ███ │ ██████ ███ │ 无法跨窗口
└─────┴─────────────┘问题:长颈鹿头和身体在不同窗口无法建立联系DAT的注意力:
┌─────────────────────┐
│ ░░░░░░░░░░░░ │
│ ░░░░████░░░░ │ ← 自由形状
│ ░░░░██████░░░ │ 适应物体轮廓
│ ░░████████░░░ │
└─────────────────────┘优势:能跨越距离聚焦长颈鹿同时忽略无关背景启示:
✓ 数据自适应的设计更优
✓ 固定模式的局限性明显
✓ DAT充分利用了数据特性
六、核心创新与设计哲学
6.1 DAT的三层创新
创新层次1:技术层面
─────────────────问题:Transformer中的可变形设计如何保持高效?解决方案:
✓ 不学习每个Query的独立偏移(会导致O(N²))
✓ 而是学习共享的参考点偏移
✓ 通过偏移分组增加表达多样性
✓ 结果:O(N·r²)的线性复杂度创新:共享偏移而非独立偏移创新层次2:设计层面
──────────────────问题:在哪些阶段使用可变形注意力最优?解决方案:
✓ 不是简单地替换所有注意力
✓ 前期阶段(1-2)用Shift-Window- 原因:高分辨率,开销大- 特点:已验证有效,足以学习局部特征
✓ 后期阶段(3-4)用Deformable- 原因:低分辨率,开销可承受- 特点:最需要全局建模,最有优势创新:因地制宜的分阶段策略创新层次3:理论层面
──────────────────问题:如何使浮点坐标的位置编码与可变形点配合?解决方案:
✓ 在预定义的位置偏置表上进行双线性插值
✓ 保证可微性(梯度连续传播)
✓ 支持任意精度的相对距离
✓ 与Swin的相对位置偏置完全兼容创新:可变形相对位置偏置的巧妙设计
6.2 与CrossFormer的对比
两篇论文都在改进ViT,但创新点不同对比维度 CrossFormer DAT
────────────────────────────────────────────
核心问题 缺乏多尺度特征 注意力模式固定
改进位置 早期层(Embedding)后期层(Attention)
主要模块 CEL + LSDA 可变形注意力
设计理念 特征多样化 注意力自适应性能表现:任务 CrossFormer DAT 胜者
────────────────────────────────────────
分类 +0.2% +0.7% DAT
检测 +1.7% +1.2% CrossFormer
分割 +4.1% +1.2% CrossFormer特点总结:CrossFormer:
✓ 对分割任务特别优化
✓ 多尺度混合的想法
✓ CEL巧妙的维度分配
✓ 适合需要多尺度的任务DAT:
✓ 对各任务均衡优化
✓ 自适应注意力的想法
✓ 可变形相对位置偏置
✓ 适合需要灵活注意力的任务可以结合吗?
理论上完全可以:
├─ 前期用CEL提供多尺度特征
├─ 后期用可变形注意力灵活利用
└─ 可能会有更好的效果
6.3 深度学习研究的启示
启示1:观察驱动创新
──────────────────GCNet等研究的观察:
├─ 不同Query的注意力权重相似
├─ 这是一个实验事实DAT的利用:
├─ 基于这个观察设计共享偏移
├─ 避免了朴素的O(N²)复杂度教训:
✓ 好的观察 → 好的设计
✓ 理论来自实验
✓ 数据驱动设计启示2:权衡的艺术
────────────────完美设计面临的权衡:
├─ 高效 vs 高性能
├─ 全局 vs 局部
├─ 计算 vs 准确DAT的权衡策略:
├─ 前期高效(Swin),后期强大(Deformable)
├─ 局部和全局的循环交替
├─ 在保持计算可承受的前提下最大化性能教训:
✓ 没有完美的解决方案
✓ 合理的权衡往往最有效
✓ 不同阶段可以用不同策略启示3:充分验证的必要性
───────────────────验证的内容:
├─ 主任务性能(分类、检测、分割)
├─ 消融研究(每个部分的贡献)
├─ 敏感性分析(超参数鲁棒性)
├─ 可视化分析(学到了什么)
├─ 对比分析(vs其他方法)DAT的实验充分:
✓ 4个大型数据集
✓ 5个不同框架(RetinaNet、Mask R-CNN等)
✓ 详细的消融表格
✓ 鲁棒性分析
✓ 可视化证据教训:
✓ 科学的研究需要充分的验证
✓ 消融研究是论文可信度的标志
✓ 可视化能提供直观理解启示4:简洁性的力量
──────────────────DAT的特点:
├─ 核心思想简洁(共享偏移)
├─ 实现优雅(双线性插值)
├─ 代码量少
├─ 易于理解和复现对比复杂的设计:
├─ ✗ 堆砌多个技巧
├─ ✗ 难以理解本质
├─ ✗ 难以复现
└─ ✗ 难以改进教训:
✓ 简洁不是简单,而是精妙
✓ 好的设计应该优雅
✓ 优雅的设计更易接受和发展
七、实现建议与代码指引
7.1 核心模块的实现要点
# 双线性插值的实现
import torch
import torch.nn.functional as Fdef deformable_sampling(features, offsets):"""在可变形点处采样特征Args:features: (B, H, W, C) 输入特征offsets: (B, H_G, W_G, 2) 每个参考点的偏移Returns:sampled_features: (B, H_G, W_G, C) 采样后的特征"""B, H, W, C = features.shapeB_g, H_g, W_g, _ = offsets.shape# 生成参考点ys = torch.linspace(-1, 1, H_g, device=features.device)xs = torch.linspace(-1, 1, W_g, device=features.device)grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')# 加上偏移sample_y = grid_y + offsets[:, :, :, 1] # [:, :, :, 1]是Δysample_x = grid_x + offsets[:, :, :, 0] # [:, :, :, 0]是Δx# 堆叠为采样网格grid = torch.stack([sample_x, sample_y], dim=-1)# 使用grid_sample进行双线性插值sampled = F.grid_sample(features.permute(0, 3, 1, 2), # (B, C, H, W)grid.unsqueeze(1), # (B, H_g, W_g, 2)mode='bilinear',padding_mode='zeros',align_corners=False)return sampled.permute(0, 2, 3, 1) # (B, H_g, W_g, C)# Offset Network的实现
class OffsetNet(nn.Module):def __init__(self, in_channels, out_channels=2):super().__init__()# 5×5深度卷积self.dw_conv = nn.Conv2d(in_channels, in_channels,kernel_size=5, padding=2,groups=in_channels)self.gelu = nn.GELU()# 1×1卷积生成偏移(无偏置)self.pw_conv = nn.Conv2d(in_channels, out_channels,kernel_size=1, bias=False)def forward(self, x):x = self.dw_conv(x)x = self.gelu(x)offsets = self.pw_conv(x)return offsets
7.2 模型使用建议
选择DAT的场景:✓ 需要高性能backbone的任务
✓ 有多个不同尺度物体的检测
✓ 对精度要求高(分割、医学影像)
✓ 可以承受小幅计算增加(+2%FLOPs)不选择DAT的场景:✗ 计算资源极其紧张
✗ 需要最快的推理速度
✗ 模型大小受严格限制