Show-o 论文解读
目录
一、Show-o
1、概述
2、方法
3、训练
一、Show-o
1、概述
视觉理解模型的发展,从单一的视觉理解,单一的图像生成,朝着视觉理解与生成统一的方向发展。但是以往的统一模型,要么是通过ViT+LLM,并将特征信息传输给diffusion用于生成图像(NExT-GPT,SEED-X),要么是通过tokenizer+LLM+de-tokenizer的方式(Chameleon),归根结底,都不是一个完整的Transformer架构。Show-o提出利用MAGVIT的分词器(本质上就是MaskGIT),实现单个Transformer同时处理理解和图像生成任务。
但随之而来存在一个问题,文本是一个离散的tokens,图像则是一个连续的tokens,二者本身存在明显差异,也不容易集成到同一个网络中。同样以往的方法都是将文本利用text encoder后直接用LLM编码,图像则需要进入扩散模型中。
Show-o为满足同时处理理解和生成任务,使用AR+diffusion混合建模,文本部分完全建立在以往LLM分词器上,保留文本推理的自回归建模能力。图像部分则采用MAGViT-v2,将图像离散化为256个token,实现与文本token的符号空间对齐。
2、方法
受益于离散去噪扩散模型(D3PMs),区别于传统扩散模型只能用于连续信息,离散去噪扩散模型可以处理离散数据(文本)间的信息,比如VQDiffusion,Copliot4D,而MaskGIT继续简化模型,并应用到图像离散化数据中,Show-o则是建立在MAGVIT-v2上。
Image Tokenization
利用MAGVIT-v2作为基础框架,训练一个无查找量化器,避免传统VQ-VAE的码本查询瓶颈。codebook size=8192,每张图片256x256被编码为16x16的离散tokens。由于MAGVIT-v2易于微调,所以未来将考虑衍生一个video tokenizer。(但是MAGVIT本身就是一个视频编码器啊,估计做了统一处理?),对于这个Image Tokenizer的架构,具体来说就是下图a,而b,c则是后续实验进行了对比。
Text tokenization
Show-o基于预训练LLaMA,使用相同的tokenizer进行文本数据标记,不做修改。
LLM整体架构
基于预训练LLM LLaMA设计,保留原始的Transformer结构,但是在每一个注意力层都添加QK-Norm操作,并新增8192个可学习嵌入向量,表示离散图像tokens。
统一提示策略
为了统一训练多模态理解和生成,设计了Unified Prompting 策略,对给定Image-text pair 通过tokenizer得到M个image tokens 和N个text tokens
。
并且根据下图的方法,设计为multi-modal understanding(多模态理解),visual generation(文生图),mixed-modality generation(混合模态生成)三种任务,其中右侧的 [MMU] 和 [T2I] 代表预定义的task token,表示执行什么具体的任务(生成文字or生成图片), [SOT] 和 [EOT] 代表text token的开始和结束token,[SOI] 和 [EOI] 代表image token的开始和结束token。
Omni-Attention机制
对于Show-o注意力机制并不是Casual attention,也不是Full attention,而是一种全新的综合注意力机制,根据输入序列的格式,自适应地混合和更改。可以理解为在不同Image和Text混合下,Casual attention和Full attention范围内的一种自适应变换。
其中Show-o通过Casual attention对sequence中的text tokens进行建模,通过Full attention对image tokens进行建模。
所以鉴于上面的统一提示策略图,提出了四种任务的注意力机制变换。
(a)多模态理解:文本关注先前所有图像token,但是文本之间只关注以前的文本token
(b)文生图:图像token可以交互所有先前文本token,但是图像间互相全交互
(c)文本建模中:退化会casual attention
(d)混合模态生成:综合以上多种方法自适应调整。
3、训练
训练目标
训练目标包含LLaMA本身的自回归(Next-token-prediction)用于处理文本的语言建模损失,以及图像离散扩散建模的扩散损失(Mask-token-prediction)。
对于给定M个image tokens 和N个text tokens
NTP:
MTP:对于输入的M个Image tokens ,首先以一定的比例(受 timestep控制)随机将图像token随机替换为[MASK] token,得到
,然后目标以unmasked区域和text token,重建原始图像的token。
基于classifier-free guidance做法,以一定的概率用空文本随机替换conditioned text token。
总损失为
训练策略
训练分为三个阶段,由于缺乏了文本编码器模块,这对于文本与图像对齐产生了很大挑战,所以我们采用三阶段的方法。
第一阶段,训练图像token嵌入(8192个新增向量)和像素依赖学习,通过纯文本RefinedWeb训练语言建模能力,图像分类库ImageNet-1K训练图像生成能力,图文对CC12M+SA1B训练基础图文对齐。
第二阶段:跨模态深度对齐,将ImageNet的分类名,转为自然语言描述训练文本对齐能力,文本描述能力。
第三阶段:高质量数据微调。利用高质量图文对LAION-aesthetics-12M,JourneyDB,训练文生图,另外通过LLaVA-Pretain-558K和LLaVA-v1.5-mix-665K训练复杂推理指令和多任务混合指令。
推理策略
对于文本的预测,直接给定图像或多模态问题,text token从具有更高置信度的预测token中自回归采样。
对于图像的预测,通过输入文本信息(N个token),和M个token [MASK]作为输入,通过show-o为每一个[MASK] token预测一个logit ,其中t是时间步,每个[MASK]token的最终预测logit使用conditional logit
和masked token的unconditional logit
。
,其中w是guidance scale
下图为去噪过程,包含T步,其中每一步保留置信度更高的image token,并替换以往的[MASK] token,随后反馈到下一轮预测。
论文参考:[2408.12528] Show-o: One Single Transformer to Unify Multimodal Understanding and Generation