Qwen2.5-VL技术详解
1. 关键技术原理
这篇题为《Qwen2.5-VL Technical Report》的技术报告详细介绍了阿里巴巴Qwen团队开发的最新多模态大模型Qwen2.5-VL。以下是对其技术原理的详细分析:
1.1. 模型架构(Model Architecture)
Qwen2.5-VL 的整体架构由三个核心组件构成:
1.1.1. 大语言模型(LLM)
- 基于 Qwen2.5 LLM 的预训练权重初始化。
- 引入了 Multimodal Rotary Position Embedding Aligned to Absolute Time(MRoPE),用于更好地处理多模态序列中的位置信息。
1.1.2. 视觉编码器(Vision Encoder)
- 采用重新设计的 Vision Transformer(ViT) 结构。
- 关键创新:
- 2D-RoPE:用于捕捉图像空间位置关系。
- Window Attention:在大多数层中使用窗口注意力机制,仅少数层使用全局注意力,显著降低计算复杂度(从平方降为线性)。
- 动态分辨率处理:输入图像按原生分辨率处理,不进行归一化,保留真实尺度信息。
- 3D Patch Partition:处理视频时,将连续两帧合并,减少 token 数量。
1.1.3. 视觉-语言融合器(Vision-Language Merger)
- 使用 两层MLP 对图像 patch 特征进行压缩和投影,使其与文本嵌入维度对齐。
- 将相邻的4个 patch 特征拼接后通过MLP,既减少计算量,又保持灵活性。
1.2. 关键技术亮点
1.2.1. 原生动态分辨率与帧率(Native Dynamic Resolution & Frame Rate)
- 空间维度:图像按原始尺寸处理,直接使用实际坐标表示边界框和点,模型能学习尺度信息。
- 时间维度:视频处理中引入 动态FPS采样 和 绝对时间编码,使模型能理解视频的时间动态。
1.2.2. 多模态旋转位置编码(MRoPE)
- 将位置编码分解为 时间、高度、宽度 三个维度。
- 在 Qwen2.5-VL 中,时间ID与绝对时间对齐,使模型能理解不同FPS下的时间一致性。
1.3. 训练策略(Training Strategy)
1.3.1. 预训练(Pre-Training)
- 数据规模从 1.2T tokens 扩展到 4.1T tokens。
- 数据来源多样:
- 图像-文本交错数据
- OCR数据(多语言支持)
- 文档解析数据(表格、图表、公式、乐谱等)
- 视频数据(动态FPS采样、长视频标注)
- 智能体交互数据(屏幕截图、UI元素标注、操作轨迹)
1.3.2. 后训练对齐(Post-Training Alignment)
- 分为两个阶段:
- SFT(Supervised Fine-Tuning):使用 ChatML 格式的多模态指令数据。
- DPO(Direct Preference Optimization):基于人类偏好进行优化。
1.3.3. 数据过滤与增强
- 使用 Qwen2-VL-Instag 进行领域分类和过滤。
- 拒绝采样(Rejection Sampling):保留高质量推理样本,提升模型推理能力。
- 多维度评分:包括图文相关性、信息互补性、信息密度平衡等。
1.4. 性能表现(Experiments)
Qwen2.5-VL 在多个基准测试中表现优异:
1.4.1. 通用视觉问答(General VQA)
- 在 MMBench、MMStar、MuirBench 等数据集上达到 SOTA。
1.4.2. 文档理解与OCR
- 在 CC-OCR、OmniDocBench、InfoVQA 等任务中领先。
1.4.3. 空间理解(Spatial Understanding)
- 在 RefCOCO、ODinW、CountBench 等任务中表现优异,支持边界框、点定位和计数。
1.4.4. 视频理解与定位(Video Understanding & Grounding)
- 在 LVBench、MLVU、Charades-STA 等长视频理解任务中超越 GPT-4o。
1.4.5. 智能体能力(Agent Capabilities)
- 在 ScreenSpot、Android Control、OSWorld 等 GUI 交互任务中表现突出。
1.5. 模型规模与适用场景
Qwen2.5-VL 提供三个版本:
- 72B:旗舰模型,性能媲美 GPT-4o、Claude 3.5 Sonnet。
- 7B:中等规模,性能优于同类竞品。
- 3B:轻量级,适合边缘设备,仍保持强大能力。
1.6. 总结
Qwen2.5-VL 的核心技术突破包括:
- 原生动态分辨率处理(空间+时间)
- MRoPE 时间对齐机制
- Window Attention 降低计算开销
- 高质量多模态数据构建与过滤
- 强大的文档解析、目标定位、视频理解能力
该模型不仅在多项基准测试中达到 SOTA,还具备强大的泛化能力和实际应用潜力,尤其在文档处理、视频分析和智能体交互方面表现突出。
2. MRoPE为什么能更好地处理多模态序列中的位置信息?
Multimodal Rotary Position Embedding Aligned to Absolute Time(MRoPE) 是 Qwen2.5-VL 相较于前代和同类模型的一个关键创新,它从“原理”层面解决了多模态序列,尤其是视频序列中位置信息编码的根本问题。
2.1. 概述
下面我们分步解析它为什么能更好地处理位置信息。
2.1.1. 首先,理解基础:RoPE (Rotary Position Embedding)
RoPE 是当今大语言模型(如 LLaMA, Qwen)的主流位置编码方案。其核心思想是:
- 通过旋转对位置进行编码:对于位置为
m
的 token,其查询(Q)和键(K)向量会被一个与m
相关的旋转矩阵所变换。 - 内在的相对位置感知:两个向量在经过旋转后,它们的点积(即注意力分数)只与它们的相对位置差
(m-n)
有关,而与它们的绝对位置m
和n
无关。这为模型理解单词顺序提供了强大的归纳偏置。
传统的 RoPE 是 1维 的,只为语言序列中 token 的先后顺序进行编码。
2.1.2. 问题所在:多模态序列的复杂性
当处理图像和视频时,序列变得复杂得多:
- 空间二维性:图像中的每个 patch 不仅有“顺序”,更有明确的 (x, y) 二维坐标。
- 时间维度的加入:视频是由一系列图像帧(2D)按时间顺序(第3维)组成的。
- 可变帧率(FPS)的挑战:不同的视频有不同的 FPS。一个关键直觉是:相隔 10 帧在两个视频中可能代表完全不同的时间跨度(例如,一个视频是 10帧/秒,10帧=1秒;另一个是 1帧/秒,10帧=10秒)。传统方法只编码“第几帧”,而无法理解“这一帧在绝对时间轴上的哪个时刻”。
2.1.3. MRoPE 的解决方案
MRoPE 的核心思想是:将高维(2D或3D)的位置信息分解到 RoPE 的每一个维度上。
a) 空间二维编码 (来自 Qwen2-VL)
MRoPE 将位置标识符分解为三个分量:( temporal, height, width )
。
- 对于文本 Token:所有三个分量都使用相同的值(即它在句子中的顺序位置),此时 MRoPE 退化为标准的 1D RoPE。
- 对于图像 Patch:
temporal
(时间)ID:对于静态图像,所有 patch 的 temporal ID 都相同(例如设为0)。height
和width
ID:根据该 patch 在图像中的实际二维坐标(i, j)
分别赋值。
这样,图像 patch 的注意力计算就同时融入了它在水平和垂直方向上的相对位置关系,模型能天然地理解“左上角”、“右下角”、“相邻”等空间概念。
b) 时间维度的绝对对齐 (Qwen2.5-VL 的关键升级)
Qwen2-VL 的 MRoPE 在处理视频时,temporal
ID 只是简单地递增(第0帧,第1帧,第2帧…)。这存在致命缺陷:它无法区分高FPS视频中的“连续帧”和低FPS视频中的“连续帧”在时间跨度上的巨大差异。
Qwen2.5-VL 的突破在于:将 temporal
ID 与绝对时间戳对齐。
- 具体做法:不再是
[0, 1, 2, 3, ...]
,而是根据视频的 FPS,将 temporal ID 设置为该帧对应的绝对时间戳(例如,单位是秒):[0.0, 0.1, 0.2, 0.3, ...]
(对于10FPS视频) 或[0.0, 1.0, 2.0, 3.0, ...]
(对于1FPS视频)。
2.1.4. 为什么“绝对时间对齐”如此有效?
这解决了视频理解中的一个根本性难题:
-
理解“时间尺度”和“事件节奏”:
- 通过使用绝对时间戳,模型能够感知到时间ID之间的间隔(ΔT)。
- 两个 temporal ID 相差
0.1
(秒)和相差1.0
(秒)在旋转角度上会产生完全不同的差异。 - 因此,模型能本能地学到:“这两个帧之间的事件发生得非常快” 或 “这两帧之间有一个很长的间隔,可能发生了缓慢的变化或跳跃”。这是传统递增ID永远无法提供的信息。
-
实现跨视频的时序一致性:
- 一个在
t=5s
的事件,无论在什么FPS的视频中,它的 temporal ID 都是5.0
。 - 这使得模型在不同FPS的视频中都能对“第5秒”有一个统一的位置编码表示,极大地增强了泛化能力。
- 一个在
-
精确的时间定位:
- 模型在输出时间戳时(例如,“请找出小狗叫的时刻”),其内部的时序位置编码与真实世界的时间轴是对应的。这直接提升了模型进行时间定位(Temporal Grounding) 的准确性。如表8所示,Qwen2.5-VL在Charades-STA上的表现远超GPT-4o。
2.1.5. 总结与类比
特性 | 传统1D RoPE | Qwen2-VL MRoPE | Qwen2.5-VL MRoPE (绝对时间对齐) |
---|---|---|---|
文本 | 处理良好 | 处理良好 | 处理良好 |
图像空间关系 | 无感知 | 优秀的2D相对位置感知 | 优秀的2D相对位置感知 |
视频时序关系 | 无感知 | 感知顺序,但不知节奏 | 优秀的时序感知,理解节奏和绝对时间 |
处理不同FPS视频 | 不一致 | 不一致 | 一致,具有泛化性 |
一个简单的类比:
- 传统方法(递增ID):就像只告诉你这是“第10页”的幻灯片,但不告诉你每页幻灯片停留了多久。
- MRoPE(绝对时间对齐):不仅告诉你这是“第10页”,还告诉你“现在是从开始后的第1分30秒”。让你能清晰地知道演讲的节奏是快是慢。
因此,MRoPE通过将高维空间和时间信息直接分解并注入到注意力机制的核心计算中,并使用绝对时间戳来统一时间表示,从而在根本原理上实现了对多模态序列(尤其是视频)中位置信息更精确、更泛化的编码。 这是Qwen2.5-VL在长视频理解、时间定位等任务上取得突破性进展的关键技术原因。
2.2. 公式、例子和代码
我们将通过公式、例子和代码三个层面,详细拆解 Qwen2.5-VL 中 Multimodal Rotary Position Embedding Aligned to Absolute Time (MRoPE) 的工作原理。
2.2.1. 核心思想与公式
MRoPE 建立在 RoPE 的基础之上。我们先回顾 RoPE 的核心公式。
a) 基础 RoPE 公式
对于位置为 m
的 token,其查询(Q)或键(K)向量的第 i
个维度对,应用旋转变换:
q~m(i)=qm(i)⋅eimθik~n(i)=kn(i)⋅einθi \begin{aligned} \tilde{q}_m^{(i)} &= q_m^{(i)} \cdot e^{im\theta_i} \\ \tilde{k}_n^{(i)} &= k_n^{(i)} \cdot e^{in\theta_i} \end{aligned} q~m(i)k~n(i)=qm(i)⋅eimθi=kn(i)⋅einθi
其中 θi\theta_iθi 是频率因子。计算注意力分数时:
⟨q~m,k~n⟩=Re[∑i=0d/2−1qm(i)(kn(i))∗ei(m−n)θi]=:⟨qm,kn⟩rope(m−n) \begin{aligned} \langle \tilde{q}_m, \tilde{k}_n \rangle &= \text{Re}[\sum_{i=0}^{d/2-1} q_m^{(i)} (k_n^{(i)})^* e^{i(m-n)\theta_i}] \\ &=: \langle q_m, k_n \rangle_{\text{rope}(m-n)} \end{aligned} ⟨q~m,k~n⟩=Re[i=0∑d/2−1qm(i)(kn(i))∗ei(m−n)θi]=:⟨qm,kn⟩rope(m−n)
注意力分数仅依赖于相对位置 m-n
。
b) MRoPE 扩展
MRoPE 将 1D 位置 m
扩展为三元组 (t, h, w)
,分别代表时间、高度、宽度维度。对每个维度独立应用 RoPE。
最终变换是三个维度旋转的复合。可以理解为向量先后在时间、高度、宽度三个维度上进行了旋转。查询向量的变换公式为:
q~(t,h,w)(i)=q(t,h,w)(i)⋅ei(tθit+hθih+wθiw) \tilde{q}_{(t, h, w)}^{(i)} = q_{(t, h, w)}^{(i)} \cdot e^{i(t\theta^t_i + h\theta^h_i + w\theta^w_i)} q~(t,h,w)(i)=q(t,h,w)(i)⋅ei(tθit+hθih+wθiw)
键向量的变换类似。最终的注意力分数依赖于三个维度上的相对位置:
(Δt, Δh, Δw) = (t_q - t_k, h_q - h_k, w_q - w_k)
2.2.2. 举例说明
假设我们有一个 2FPS(每秒2帧)的视频,每帧分辨率很低,只有 2x2 patches。我们来看 MRoPE 如何为这些 patches 编码。
- 视频信息: FPS=2 → 每帧间隔 0.5 秒。
- 第 0 帧 (t=0.0s) 的 patches 位置 ID:
(t=0.0, h=0, w=0)
(t=0.0, h=0, w=1)
(t=0.0, h=1, w=0)
(t=0.0, h=1, w=1)
- 第 1 帧 (t=0.5s) 的 patches 位置 ID:
(t=0.5, h=0, w=0)
(t=0.5, h=0, w=1)
(t=0.5, h=1, w=0)
(t=0.5, h=1, w=1)
计算示例:计算第1帧左上角 patch (t=0.5, h=0, w=0)
与第0帧所有 patches 的注意力。
- 与第0帧自身
(0.0, 0, 0)
比较:- 相对位置:
(Δt=0.5, Δh=0, Δw=0)
- 模型能感知到这两个 patch 在空间上是同一个位置,但时间上相差 0.5 秒。
- 相对位置:
- 与第0帧右边 patch
(0.0, 0, 1)
比较:- 相对位置:
(Δt=0.5, Δh=0, Δw=-1)
- 模型能感知到时间差和水平方向上的偏移。
- 相对位置:
- 与一个 1FPS 视频比较:如果一个视频是 1FPS,它的第1帧 temporal ID 是
t=1.0
。- 相对位置
(Δt=1.0, Δh=0, Δw=0)
的旋转角度差是 2FPS 视频(Δt=0.5, ...)
的两倍。 - 模型因此能本能地区分出“过了半秒”和“过了一秒”的本质区别,从而理解事件的节奏。
- 相对位置
2.2.3. 代码分析
让我们深入 Qwen2.5 的代码库来验证上述过程。
a) 位置ID的构建 (modeling_qwen.py
)
代码位置: qwen2_vision_modeling_qwen.py
或类似的模型文件中,在 forward
方法里构建位置ID。
# 伪代码,展示逻辑
def forward(self, hidden_states, images=None, past_key_values=None, ...):# ... 前面的处理 ...# 1. 构建序列的 position_ids# 对于纯文本部分,使用标准的 1D 位置ID [0, 1, 2, ..., seq_len-1]if position_ids is None:position_ids = torch.arange(seq_len, dtype=torch.long, device=device)position_ids = position_ids.unsqueeze(0)# 2. 如果是多模态输入(有图像/视频),需要构建复杂的 position_idsif images is not None:# 调用一个特定的函数来为图像/视频 patches 生成 (t, h, w) IDvision_position_ids = self._build_vision_position_ids(images)# 将文本的 position_ids 和视觉的 vision_position_ids 拼接起来position_ids = torch.cat([text_position_ids, vision_position_ids], dim=1)# ... 将 position_ids 传入后续的 RoPE 应用层 ...
关键点:位置ID不再是一个简单的一维张量,而是根据输入模态动态构建的。
b) MRoPE 的应用 (rotary_emb.py
)
代码位置: rotary_emb.py
中的 apply_rotary_pos_emb
函数或 Qwen2VisionRMSNormRotaryEmbedding
类。
这是应用旋转位置编码的核心函数。标准的 apply_rotary_pos_emb
可能被重写以处理多维位置信息。
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):"""q, k: 查询和键向量 [batch_size, seq_len, num_heads, head_dim]cos, sin: 预先计算好的余弦和正弦值position_ids: 位置ID,现在可能是多维的 [batch_size, seq_len, 3] (t, h, w)"""# 1. Gather the cos and sin values for the corresponding positions# 根据 position_ids 索引,获取每个token在三个维度上的cos/sincos_t = cos[position_ids[..., 0]] # 时间维度的cossin_t = sin[position_ids[..., 0]]cos_h = cos[position_ids[..., 1]] # 高度维度的cossin_h = sin[position_ids[..., 1]]cos_w = cos[position_ids[..., 2]] # 宽度维度的cossin_w = sin[position_ids[..., 2]]# 2. 复合旋转:依次应用三个维度的旋转# 假设 q_embed 是原始的查询向量# 首先在时间维度上旋转q_embed = apply_single_axis_rotary(q, cos_t, sin_t, dim=‘time’)# 然后在高度维度上旋转q_embed = apply_single_axis_rotary(q_embed, cos_h, sin_h, dim=‘height’)# 最后在宽度维度上旋转q_embed = apply_single_axis_rotary(q_embed, cos_w, sin_w, dim=‘width’)# 对 k 进行同样的操作k_embed = apply_single_axis_rotary(k, cos_t, sin_t, dim=‘time’)k_embed = apply_single_axis_rotary(k_embed, cos_h, sin_h, dim=‘height’)k_embed = apply_single_axis_rotary(k_embed, cos_w, sin_w, dim=‘width’)return q_embed, k_embed# 单轴旋转的辅助函数 (概念性代码)
def apply_single_axis_rotary(x, cos, sin, dim):"""x: 输入向量cos, sin: 该维度上的旋转余弦和正弦值dim: 指定旋转应用的维度(对应向量中的哪一组分量)"""# 将 x 的最后一维(head_dim)分成两组,分别作为复数的实部和虚部x1, x2 = x.chunk(2, dim=-1)# 应用旋转公式: (x1 + i*x2) * (cos + i*sin) = (x1*cos - x2*sin) + i*(x1*sin + x2*cos)rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)return rotated_x
关键点:MRoPE 的实现本质上是对 Q/K 向量依次进行三个独立维度的旋转变换。
c) 绝对时间对齐的实现 (_build_vision_position_ids
)
代码位置: 在模型文件中寻找 _build_vision_position_ids
方法。这是实现“绝对时间对齐”的关键。
def _build_vision_position_ids(self, images):"""images: 输入图像/视频数据,可能包含帧时间戳信息returns: 位置ID张量 [batch, num_patches, 3]"""batch_size, num_frames, channels, height, width = images.shapenum_patches_per_frame = (height // self.patch_size) * (width // self.patch_size)# 1. 获取绝对时间戳(例如,从视频元数据中,或根据FPS计算)# 假设 images 对象有一个属性 `frame_timestamps` [batch, num_frames]# 例如,一个2FPS的视频,num_frames=2: [[0.0, 0.5]]absolute_timestamps = images.frame_timestamps # 2. 为每个帧内的每个 patch 生成空间坐标 (h, w)h_coords = torch.arange(0, height // self.patch_size, device=device)w_coords = torch.arange(0, width // self.patch_size, device=device)grid_h, grid_w = torch.meshgrid(h_coords, w_coords, indexing='ij')# 空间坐标在所有帧中重复spatial_positions = torch.stack([grid_h.flatten(), grid_w.flatten()], dim=-1) # [num_patches_per_frame, 2]spatial_positions = spatial_positions.repeat(batch_size, num_frames, 1) # [batch, num_frames, num_patches_per_frame, 2]# 3. 将绝对时间戳分配给每个帧的每一个patch# 将 absolute_timestamps 扩展维度以匹配 spatial_positionst_positions = absolute_timestamps[:, :, None, None] # [batch, num_frames, 1, 1]t_positions = t_positions.repeat(1, 1, num_patches_per_frame, 1) # [batch, num_frames, num_patches_per_frame, 1]t_positions = t_positions.squeeze(-1) # [batch, num_frames, num_patches_per_frame]# 4. 组合成最终的 (t, h, w) ID 张量# 将空间坐标从 [h, w] 两个值合并position_ids = torch.cat([t_positions.unsqueeze(-1), # t dimensionspatial_positions[..., 0:1], # h dimensionspatial_positions[..., 1:2] # w dimension], dim=-1) # [batch, num_frames, num_patches_per_frame, 3]# 5. 重塑为最终的序列格式 [batch, total_patches, 3]position_ids = position_ids.reshape(batch_size, -1, 3)return position_ids
关键点:t_positions
不再是简单的 [0, 1, 2, ...]
,而是从数据中获取或计算出的绝对时间值。这行代码是实现“绝对时间对齐”的灵魂所在。
总结
通过公式、例子和代码的三重剖析,我们可以看到 Qwen2.5-VL 的 MRoPE 机制是如何工作的:
- 公式上:它将 1D RoPE 优雅地扩展为 3D,通过对 Q/K 向量依次施加时间、高度、宽度三个维度的旋转变换,将相对位置信息编码到注意力计算中。
- 概念上:它让模型能同时理解“什么时候”、“在哪里”,并且通过使用绝对时间戳,让模型获得了感知“时间流逝速度”(节奏)的能力,这是相比前代的质的飞跃。
- 实现上:在代码中体现为:
- 构建包含
(t, h, w)
三元组的position_ids
(_build_vision_position_ids
)。 - 在 RoPE 计算层,根据这三个维度的ID分别查找并应用三次旋转 (
apply_rotary_pos_emb
)。 t
维度的值来源于绝对时间,而非帧序号。
- 构建包含
这种从底层注意力机制入手的设计,使得 Qwen2.5-VL 在视频理解任务上具备了强大的时空推理基础。
3. 窗口注意力为什么能显著降低计算复杂度(从平方降为线性)?
窗口注意力(Window Attention)
1. 问题根源:标准自注意力的计算复杂度
标准的多头自注意力(MSA)机制是 Transformer 的核心,但其计算和内存成本是序列长度的平方级 O(n²)
。
原因在于:注意力矩阵的计算需要每个 token 与其他所有 token 进行交互。
- 公式: Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V
Q
,K
,V
的形状都是[序列长度 (n), 特征维度 (d)]
QK^T
这一步会产生一个[n, n]
的矩阵(即注意力分数矩阵)。这个矩阵的每个元素都代表一个 token 对另一个 token 的关注程度。- 计算量: 计算
QK^T
需要n * n * d = n²d
次操作。当序列长度n
很大时(例如处理高分辨率图像或长视频),n²
会变得极其巨大,成为计算瓶颈。
举例:
假设一张图片被分成 n = 256 × 256 = 65,536
个 patch。标准注意力的注意力矩阵大小将是 65,536 × 65,536
,这约等于 42.9 亿个元素,这在当前硬件上是无法直接计算的。
2. 解决方案:窗口注意力(Window Attention)
窗口注意力的核心思想是:一个 token 不必关注所有其他 token,只需关注其周围一个局部窗口(Window)内的 token。这是一种强大的归纳偏置,在图像领域非常有效,因为像素(或patch)的相关性通常随距离增加而减弱。
具体做法:
- 将输入特征图均匀地划分为多个不重叠的(或重叠的)窗口。
- 只在每个窗口内部计算标准的多头自注意力。
- 各个窗口之间的计算是完全独立的,可以并行处理。
3. 复杂度分析:从平方(O(n²))到线性(O(n))
假设:
- 总 token 数:
n = H × W
(对于图像) - 每个窗口的 token 数:
M × M
(例如,M=14
或M=112
) - 窗口数量:
(H/M) × (W/M)
现在我们来计算复杂度:
-
每个窗口的计算复杂度:
- 每个窗口需要计算一个
[M², M²]
的注意力矩阵。 - 计算一个窗口的
QK^T
需要(M²) * (M²) * d = M⁴d
次操作。
- 每个窗口需要计算一个
-
总计算复杂度:
- 总窗口数为
(H/M) × (W/M) = n / M²
。 - 因此,总计算量为:
(n / M²) * (M⁴d) = n * M² * d
。
- 总窗口数为
结论:总计算复杂度变成了 O(n * M² * d)
。
M
(窗口大小)是一个固定的超参数,不与输入图像大小n
挂钩。d
(特征维度)也是一个固定值。- 所以,总计算量只与 token 总数
n
呈线性关系,即O(n)
。
继续上面的例子:
n = 65,536
个 patch。如果我们设置窗口大小 M = 112
。
- 标准注意力:计算量 ∝
65,536² ≈ 4.29e9
- 窗口注意力:计算量 ∝
65,536 * (112)² ≈ 65,536 * 12,544 ≈ 8.22e8
- 计算量降低了约 5.2 倍。如果图像更大,节省的计算量会更加惊人。
4. Qwen2.5-VL 中的代码实现
在 Qwen2.5-VL 的 Vision Transformer (ViT) 中,并非所有层都使用窗口注意力。报告指出:只有4层使用全局注意力,其余层使用窗口注意力。这是一种常用的设计,在保证模型具有全局建模能力的同时,极大地降低了计算成本。
我们可以在其 ViT 实现的代码中找到相关证据(代码位置通常在 models/vision_transformer.py
或类似文件中):
class Qwen2VisionAttention(nn.Module):def __init__(self, config, layer_idx: int):super().__init__()self.config = configself.layer_idx = layer_idx# ... 初始化Q, K, V投影层等 ...# 判断当前层是否使用窗口注意力# 根据技术报告,只有索引为 [7, 15, 23, 31] 的层使用全局注意力self.use_window_attention = layer_idx not in config.full_attention_block_indexesif self.use_window_attention:self.window_size = config.window_size # 例如 112else:self.window_size = None # 全局注意力def forward(self, hidden_states, attention_mask=None):batch_size, seq_len, dim = hidden_states.shape# 1. 投影得到Q, K, Vqkv = self.qkv_proj(hidden_states)# ... 重塑qkv为多头格式 ...# 2. 应用RoPE位置编码 (Qwen2.5-VL使用2D-RoPE)# ... 此处省略RoPE代码 ...# 3. 核心:计算注意力if self.use_window_attention:# ******** 窗口注意力路径 ********# 将序列重塑为图像格式 [batch, height, width, heads, head_dim]hidden_states = hidden_states.view(batch_size, self.num_heads, height, width, -1)# 使用PyTorch的fold/unfold或自定义函数进行窗口划分# 这里是一个概念性实现windows = window_partition(hidden_states, self.window_size) # [B*num_windows, window_size, window_size, heads, head_dim]# 计算窗口内的注意力attn_output = self._compute_attention_within_windows(windows)# 将窗口合并回完整特征图attn_output = window_reverse(attn_output, self.window_size, height, width)else:# ******** 全局注意力路径 (仅在第7,15,23,31层使用) ********# 使用标准注意力,计算整个序列的QK^T,复杂度为O(n²)attn_weights = torch.matmul(query, key.transpose(-1, -2)) # [batch, heads, seq_len, seq_len]attn_weights = attn_weights / math.sqrt(self.head_dim)if attention_mask is not None:attn_weights = attn_weights + attention_maskattn_weights = nn.functional.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_weights, value)# 4. 投影输出attn_output = self.out_proj(attn_output)return attn_outputdef _compute_attention_within_windows(self, windows):"""在一个窗口内计算标准自注意力"""# windows shape: [batch*num_windows, window_size*window_size, heads, head_dim]b_win, n_patches, n_heads, d_head = windows.shapewindows = windows.view(b_win, n_patches, -1) # 合并头和维度以便于计算# 投影得到Q, K, V (这里简化了,实际QKV投影可能在窗口划分前已完成)q = self.q_proj(windows)k = self.k_proj(windows)v = self.v_proj(windows)# 计算窗口内的注意力,复杂度为O( (window_size²)² )attn_weights = torch.matmul(q, k.transpose(-1, -2))attn_weights = attn_weights / math.sqrt(self.head_dim)attn_weights = nn.functional.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_weights, v)return attn_output.view(b_win, n_patches, n_heads, d_head)# 辅助函数:将特征图划分为窗口
def window_partition(x, window_size):"""Args:x: (B, H, W, C)window_size (int): 窗口大小Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows
关键代码解读:
self.use_window_attention
:根据层索引决定是否使用窗口注意力。window_partition
函数:负责将[B, H, W, C]
的特征图划分成多个[B*n_win, M, M, C]
的窗口。_compute_attention_within_windows
:在每个窗口内部执行标准的、计算复杂度为O(M⁴)
的自注意力计算。由于M
是固定的,每个窗口的计算量是常数。- 总计算量 = 窗口数量
(n / M²)
× 常数(M⁴)
=O(n * M²)
,即与 token 数n
成线性关系。
总结
特性 | 标准全局注意力 | 窗口注意力 |
---|---|---|
计算复杂度 | O(n²d) | O(nM²d) |
内存占用 | O(n²) | O(nM²) |
设计思想 | 每个 token 关注所有 token | 每个 token 只关注局部窗口内的 token |
优势 | 强大的全局建模能力 | 计算高效,适合高分辨率输入 |
劣势 | 计算和内存成本高昂 | 需要其他机制(如移位窗口、全局层)来促进窗口间通信 |
Qwen2.5-VL 通过混合使用少数几层全局注意力和多数层窗口注意力,在保证模型具备全局感知能力的前提下,成功地将其视觉编码器的计算复杂度从难以处理的 O(n²)
降低到了可接受的 O(n)
,这是其能够高效处理原生高分辨率图像和长视频的关键技术支柱。