视觉_transform
visual_transform
图像分块 (Patch Embedding)
-
假设输入图像为 x ∈ R ∗ H ∗ × ∗ W ∗ × ∗ C ∗ x∈R^{*H*×*W*×*C*} x∈R∗H∗×∗W∗×∗C∗
-
C 是图像的通道数(例如,RGB图像的 C=3)
-
将图像分割成N个大小为P*CP的patch,每个patch的大小为 P × P × C P×P×C P×P×C
-
N = H ∗ W P 2 N = \frac{H*W}{P^2} N=P2H∗W
-
-
将每个patch展平为一个向量,展平后的向量长度为
P 2 ∗ C P^2 * C P2∗C -
将每个展平后的patch向量通过一个线性投影(全连接层)映射到一个 D维的嵌入空间
-
这个线性投影是可学习的,其权重矩阵为 E ∈ R ( P 2 ∗ C ) ∗ D E∈R^{(P^2*C)*D} E∈R(P2∗C)∗D
-
公式表示为:
X p 是 R ( P 2 ∗ C ) 展平后的 p a t c h 向量 Z p = X p ∗ E X_p是R^{(P^2*C)}展平后的patch向量\\ Z_p = X_p*E Xp是R(P2∗C)展平后的patch向量Zp=Xp∗E
-
-
将所有patch嵌入向量 Z p Z_p Zp 按顺序排列,形成一个序列:
z = [ z 1 , z 2 , . . . , z n ] z=[z_1, z_2, ...,z_n] z=[z1,z2,...,zn] -
小结
- 将图像分割成固定大小的patch,每个patch被视为一个“单词”
- patch的大小和数量决定了模型的输入序列长度
- 将每个patch展平并通过线性投影映射到高维空间,形成patch嵌入向量
- 线性投影的权重是可学习的,模型通过训练优化这些权重
- 将所有patch嵌入向量按顺序排列,形成一个序列,作为Transformer编码器的输入
位置编码(Positional Encoding)
-
在ViT中,位置编码是一个可学习的向量,与patch嵌入向量的维度相同(即 D 维),j假设有N个patch,则位置编码矩阵为:
E p o s ∈ R ( N + 1 ) ∗ D E_pos ∈R^{(N+1)*D} Epos∈R(N+1)∗D- N 是patch的数量
- D 是patch嵌入向量的维度
- 额外的 +1 是为了处理[class] token(分类token)
-
将位置编码添加到patch嵌入向量中:
z = [ z 1 , z 2 , . . . , z n ] + E p o s z=[z_1, z_2, ...,z_n] + E_{pos} z=[z1,z2,...,zn]+Epos -
在ViT中,通常会在patch序列的开头添加一个额外的[class] token(分类token),用于最终的分类任务。这个[class] token也会被赋予一个位置编码
Swing transformer
- 传统transformer中,自注意力机制(Self-Attention)的计算复杂度为 O(N^2),N为输入长度;
- 自注意力机制首先会计算序列中每个字对所有其他字的注意力分数
- 假设序列长度为N,那么对于序列中的每个字,我们都需要计算它与序列中其他N-1个字的注意力分数,这是一个N×(N-1)
- 对于高分辨率图像,输入序列长度 N会非常大(eg:224*224的图像展平后序列长度为50176)
- Swin Transformer 提出了滑动窗口机制,将自注意力计算限制在局部窗口内,从而将计算复杂度从O(N2)降低到O(M2 * N)
- 通过窗口移位(Window Shift),使不同窗口之间能够交互信息,从而捕捉全局上下文
eg:
-
假设有一个输入特征图,大小为 4x4
A1 A2 A3 A4 B1 B2 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4
-
选择窗口大小为 2x2,那么可以将特征图划分为 4 个不重叠的窗口
窗口 1: A1 A2 窗口 2: A3 A4 B1 B2 B3 B4 窗口 3: C1 C2 窗口 4: C3 C4 D1 D2 D3 D4
在每个窗口内,模型会计算 自注意力(Self-Attention),而不是在整个特征图上计算。这样可以显著降低计算复杂度
例如,在 窗口 1 中,模型会计算 A1、A2、B1、B2 之间的自注意力关系
-
Swin Transformer 引入了 窗口移位 机制。具体来说,窗口会向右下角移动 1 个 patch(即窗口大小的一半),然后重新划分窗口
A1 A2 A3 A4 B1 B2 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4
-
移位后,窗口重新划分为 4 个新的窗口
窗口 1: A2 A3 窗口 2: A4 B1 B2 B3 B4 C1 窗口 3: C2 C3 窗口 4: C4 D1 D2 D3 D4 D2
在移位后的窗口中,模型会再次计算自注意力。例如,在 窗口 1 中,模型会计算 A2、A3、B2、B3 之间的自注意力关系
窗口还原和特征融合
- 每次窗口移位后,模型会计算一次 移位窗口自注意力(Shifted Window Multi-Head Self-Attention, SW-MSA);会产生多个注意力分数,通过 窗口还原(Window Reversal) 和 特征融合 来实现
- 每次窗口移位后,模型会计算一次自注意力,得到新的特征表示
- 在计算完移位窗口的自注意力后,模型需要将特征图从移位后的窗口还原回原始布局
- 例如,假设窗口大小为 2x2,窗口移位后,特征图被重新划分为新的窗口。计算完 SW-MSA 后,模型会将特征图还原回原始的 4x4 布局
- 在 Swin Transformer 中,W-MSA 和 SW-MSA 是交替使用的。每个 Swin Transformer 块(Block)包含一个 W-MSA 和一个 SW-MSA
- W-MSA 和 SW-MSA 的输出特征图会通过 残差连接(Residual Connection) 进行融合
层次化设计
-
Patch Merging:将相邻的 patch 合并,下采样特征图
- 假设输入图像大小为 224x224,patch 大小为 4x4
- 输入特征图大小:56x56(224/4)
- 每个阶段中,包含多个 Swin Transformer 块,每个块由 W-MSA(Window Multi-Head Self-Attention, W-MSA) 和 SW-MSA 组成
-
在每个阶段的开始,通过 Patch Merging 将特征图下采样
- 经过 Patch Merging 后,特征图大小变为 28x28,通道数增加 4 倍
- 包含多个 Swin Transformer 块
-
经过 Patch Merging 后,特征图大小变为 14x14,通道数增加 4 倍
- 包含多个 Swin Transformer 块
-
经过 Patch Merging 后,特征图大小变为 7x7,通道数增加 4 倍
- 包含多个 Swin Transformer 块
-
输入特征图大小:7x7
- 不进行 Patch Merging,直接包含多个 Swin Transformer 块
-
经过 Patch Merging 后,特征图大小变为 7x7,通道数增加 4 倍
- 包含多个 Swin Transformer 块
-
输入特征图大小:7x7
- 不进行 Patch Merging,直接包含多个 Swin Transformer 块
-