当前位置: 首页 > news >正文

ViT学习

图像块嵌入(Patch Embeddings)

为了将Transformer应用到2D图像中,ViT将图像\mathbb{R}^{H\times W\times C}展平成2D的patches矩阵\mathbb{R}^{N\times (P^{2}\cdot C)},其中H、W分别为图像的height(高)和width(宽),C是原图像通道数,(P,P)是每个图像patch的分辨率,N为图像的patch数(N=H\times W/P^{2})。所以原图转化成了一组一维的patch序列。

P^{2}\cdot C的结果可能很大,Transformer不能直接处理高维的向量,因此我们定义一个可训练的权重矩阵E\in \mathbb{R}^{(P^{2}\cdot C)\times D},把每个patch向量x_{p}映射到D维空间:

x_{p}E\epsilon R^{N\times D}

这一步就是Patch Embedding,把每一个patch向量映射到一个固定的维度中。

最终输入给Transformer的就是N个D维向量。整个过程类似于把文本的one-hot encoding映射为固定维度词嵌入的过程。

可学习的分类Token(Class Token)

在得到N个patch向量输入后,Transformer会输出N个向量,这样到底用哪一个来代表整张图像做分类呢?类似于NLP里BERT在句子最前面加一个[CLS]token专门用来承载整句话的语义,在ViT里,我们在patch序列的最前面手动添加一个可学习的嵌入向量x_{class}。所以输入序列变成:

x_{class},x_{p}^{1},x_{p}^{2}, ... , x_{p}^{N}

序列长度变为N+1。

这个x_{class}一开始是随机初始化的,在自注意力机制中,class token会作为Query,和所有patch的Key/Value交互;它通过注意力权重从所有 patch 中获取信息,并在多层 Transformer 中不断更新,最终成为图像的全局语义表示。

在分类头中,我们便通过这个Class Token做分类任务。

位置嵌入(Position Embeddings)

Transformer的自注意力机制有一个特点:扰动不变性(Permutation-invariant),即如果我们打乱输入序列中token的顺序,self-attention本身不会知道顺序被打乱了。所以在Transformer中采用了固定位置编码。在ViT论文中,提供了四种方案:

1. 无位置嵌入:效果最差

2. 1D位置嵌入(1D-PE):把图像当作一维序列,每个位置分配一个可学习的向量

3. 2D位置嵌入(2D-PE):用图像块的二维坐标 (x,y) 来生成位置向量,更贴近图像的空间结构

4. 相对位置嵌入(RPE):考虑token与token之间相对关系

2/3/4方案效果相差不大,因为patch本身就包含了一部分局部空间结构信息。

Transformer Encoder

详见:Transformer

一个Transformer Encoder Block主要包含一个MSA(多头自注意力层)和一个MLP/FFN(多层感知机/Feed-forward network),MLP/FFN包含两个FC(全连接层)+中间的非线性激活函数GeLU。层内通过残差结构和层归一化层连接。

推理过程

1. 输入图像,划分patch:

224\times 224\times 3\rightarrow 196\times (16\times 16\times 3)

2. 通过线性投影层映射到维度为D的嵌入空间,得到patch embedding序列:

z_{1}^{0},z_{2}^{0}, ... ,z_{N}^{0}

每个向量长度为D

3. 加入class token,序列长变为N+1:

z_{cls}^{0}, z_{1}^{0},z_{2}^{0}, ... ,z_{N}^{0}

4. 加入位置嵌入,默认使用1D绝对位置嵌入:

z_{i}^{0}=z_{i}^{0}+E_{pos}[i], i=0, ... , N

5. 经过L层Transformer编码器

(1)多头注意力子层(MSA),所有token两两交互(多组权重矩阵):

Q_{i}=z_{i}W_{Q}, K_{i}=z_{i}W_{K}, V_{i}=z_{i}W_{V}

Attn(i,j)=\frac{Q_{i}\cdot K_{j}}{\sqrt{D_{k}}}

z_{i}=\sum_{j}^{}softmax_{j}(Attn(i,j))V_{j}

(2)前馈网络子层(MLP):每个token独立通过两层全连接+GeLU激活:

MLP(z) = W_{2}GeLU(W_{1}z+b_{1})+b_{2}

(3)残差连接+LayerNorm:每个子层前先对输入做LayerNorm,再加上残差

重复L层后,得到最终序列表示

6. 分类

得到多层交互后的class token,此时该向量已融合全图信息,把z_{cls}^{L}送入分类头,输出每个类别的概率分布。

归纳偏置(Inductive Bias)与优化

归纳偏置:模型在设计时人为注入的先验知识,帮助它更快或更高效地学习。

CNN有以下特点:

1. 局部性(Locality):卷积核只在局部区域滑动,所以先关注小范围特征

2. 二维邻域结构 (2D neighborhood structure):卷积核在二维空间移动,天然利用了图像的二维网格结构

3. 平移等变性 (translation equivariance):输入图像发生平移,卷积后的特征图也会以同样的方式平移

以上特点使CNN在处理图像时已经帮模型带来了很多图像结构的认识。

而在ViT中唯一和图像局部相关的设计是一开始的patch划分。之后所有空间关系都要依靠训练学习出来。位置嵌入在初始化时也不知道任何2D空间结构,也是依靠训练逐渐学习的。换句话说,ViT几乎完全按照数据来学习图像结构;CNN则在网络结构里就内置了很多图像知识。

为了弥补ViT的归纳偏置不足,可以考虑 CNN+ViT 的混合方式。即不直接把原始图像切成 patch,而是先用CNN提取特征图;然后把CNN的特征图按照ViT的方式做处理。这样,CNN 已经做了局部感知和降采样,减少了 Transformer 的学习负担,同时ViT可以在CNN特征的基础上学习全局依赖。

微调

通常,我们在大型数据集上预训练 ViT,并对 (更小的) 下游任务进行微调。


文章转载自:

http://sysrARSO.qsfys.cn
http://mWNtPBjd.qsfys.cn
http://2jUzGBqW.qsfys.cn
http://EZWI9RIr.qsfys.cn
http://3PWknN0k.qsfys.cn
http://g4Fh8t0l.qsfys.cn
http://plCDJPW0.qsfys.cn
http://DIHdEs7K.qsfys.cn
http://2Qn2jHWQ.qsfys.cn
http://Z8yO5fk6.qsfys.cn
http://hzPELVC6.qsfys.cn
http://ehkoJ9rW.qsfys.cn
http://AZobeLaE.qsfys.cn
http://884ivjah.qsfys.cn
http://ptzUGcLT.qsfys.cn
http://dUteDlaa.qsfys.cn
http://Yoa052wf.qsfys.cn
http://GcBwk9uD.qsfys.cn
http://iH32L49l.qsfys.cn
http://FveU4LAZ.qsfys.cn
http://EPrtOIVL.qsfys.cn
http://ze4ZFLRW.qsfys.cn
http://4eX5d0m7.qsfys.cn
http://cYqB25G2.qsfys.cn
http://Vcbx48p9.qsfys.cn
http://XVW8wC28.qsfys.cn
http://e1IeMTse.qsfys.cn
http://7IWzKzNT.qsfys.cn
http://KHSdo3QD.qsfys.cn
http://6vKGf7ro.qsfys.cn
http://www.dtcms.com/a/372376.html

相关文章:

  • 【Java实战㉚】深入MyBatis:从动态SQL到缓存机制的进阶之旅
  • 腾讯云EdgeOne免费套餐:零成本开启网站加速与安全防护
  • Cookie-Session 认证模式与Token认证模式
  • Redis哨兵模式在Spring Boot项目中的使用与实践
  • [工作表控件13] 签名控件在合同审批中的应用
  • 【图像理解进阶】MobileViT-v3核心技术解析和应用场景说明
  • 前端拖拽功能实现全攻略
  • AI赋能软件开发|智能化编程实战与未来机会有哪些?
  • 335章:使用Scrapy框架构建分布式爬虫
  • Docker|“ssh: connect to host xxx.xxx.xxx.xxx port 8000: Connection refused“问题解决
  • OneCode 可视化揭秘系列(三):AI MCP驱动的智能工作流逻辑编排
  • 数据结构深度解析:二叉树的基本原理
  • Supabase02-速通
  • LLM学习:大模型基础——视觉大模型以及autodl使用
  • 嵌入式Secure Boot安全启动详解
  • 【倍增】P3901 数列找不同|普及+
  • 数据结构:堆
  • 继续优化基于树状数组的cuda前缀和
  • 数组常见算法
  • 数仓建模理论
  • 致远A8V5 9.0授权文件
  • 【New Phytologist】​​单细胞多组学揭示根毛对盐胁迫的特异性响应文献分享
  • MyBatis 拦截器让搞定监控、脱敏和权限控制
  • 20250907-0101:LangChain 核心价值补充
  • 论CMD、.NET、PowerShell、cmdlet四者关系
  • 从IFA展会看MOVA的“全维进阶”如何重新定义智能家居边界
  • SpringBoot 数据脱敏实战: 构建企业级敏感信息保护体系
  • 公链分析报告 - 模块化区块链1
  • 20250907-01:理解 LangChain 是什么 为什么诞生
  • 做一个鉴权系统