当前位置: 首页 > news >正文

【论文阅读】MaskGIT: Masked Generative Image Transformer

CVPR 2022,一个用于高质量图像生成和编辑的Transformer架构。

现有方法的局限

  • GANs虽然能生成高质量图像,但训练不稳定(比如模式崩溃问题,导致样本多样性不够)。

  • 生成式Transformer受到NLP成功启发,将图像当作一维序列处理(如VQGAN),但实际上:

    • 图像不是天然的序列结构,“一维展平+顺序生成”不能充分利用图像的空间结构和上下文。

    • 序列长度长(如一个512x512图像可能需要几千个token),导致建模和生成效率低下。

    • Autoregressive解码(逐步生成)速度慢且不易并行(比如生成一张图要花30秒)。


2. MaskGIT的核心创新

MaskGIT提出了一套与BERT类似的 “掩码学习”(masked token prediction)思路用于图像生成:

a. 训练阶段

  • 把图像编码为离散的token。

  • 随机掩盖一些token,Transformer需根据未被掩盖的token来预测被遮住(masked)的位置的内容(类似BERT的MLM)。

  • 网络是双向自注意力(Bidirectional self-attention),可以同时用上下左右信息,增强空间建模。

b. 推理/生成阶段

区别传统方法(逐步生成),MaskGIT采用并行多步生成:

  1. 初始状态:所有位置为“未知”token(全mask)。

  2. 第一步:模型并行预测所有token,但只保留“置信度最高”的一部分。

  3. 后续步:继续掩盖“置信度低”的token、并行预测并更新直至全部token都确定。整个过程通常只需几步(如8步),远快于传统的数百步逐个生成。

解码速度优化:MaskGIT的解码是并行的,并且只用很少的迭代步,对于512x512图像,生成效率是autoregressive方法的64倍!

c. Bidirectional vs. Unidirectional

  • 传统生成式Transformer只看已生成的token,MaskGIT可用全局信息预测token,能更好生成与图像上下文相关的内容。

d. 掩码调度(Mask Scheduling)

  • 每一步掩盖的比例逐步减少,采用cosine schedule能提升生成质量(论文有消融实验对比)。


总结:

  • MaskGIT本质上是一种将BERT式掩码预测引入视觉生成领域的新范式。

  • 并行、迭代掩码生成机制 在速度和质量上都有巨大优势。

  • Bidirectional attention让图像空间建模更强,局部修改更灵活。

  • 不仅适合高质量生成,也能做复杂的图像编辑、补全、外推等任务,可谓通用型强。


  • Figure 1展示了MaskGIT已能做图像生成、对象替换、图像扩展等多种任务,且分辨率很高。

  • Figure 2对比了传统逐步生成与MaskGIT的并行掩码生成解码过程,8步搞定高分辨率图像,极大提升效率。

  • Figure 3:方法全流程 Pipeline

两阶段设计(跟 VQGAN、VQVAE 类似):

  1. Tokenizer(第一阶段)
    输入图像 → VQGAN/VQ-VAE 编码器 → 离散化成“视觉 token”(即 codebook 索引)。

    • 图像大小 H×W 压缩成 h×w(h=H/16, w=W/16)token grid。

    • 每个 token 是一个 codebook id(共有 1024 个 token,可视作“视觉单词”)。

  2. Bidirectional Transformer(第二阶段)

    • 训练目标是 MVTM(Masked Visual Token Modeling),即随机mask部分视觉token,让模型根据未masked token预测被mask掉的内容(类似BERT的MLM)。

    • 跟之前自回归序列生成不同,Transformer是双向注意力,可以同时看到左、右、上、下的内容来猜mask位置。


3.1 MVTM 训练

  • 假设原始token序列为:

    Y={yi​}i=1N​

    N表示token总数(h×w)。

  • mask向量 M={mi​}i=1N​ ,m_i=1表示该token被mask替换为特殊符号[MASK],m_i=0表示原token保留。

  • mask比例γ(r):
    从0到1采样一个比率,根据这个比率γ(r)×N随机挑token做mask。
    (训练时随机,推理时按进度递减,下文3.3会细讲)

  • 损失函数:
    只在mask位置计算交叉熵损失:

    Lmask​=−EY∼D​i:mi​=1∑​logp(yi​∣YM​)

    其中 YM​ 是mask后的token序列。

  • 核心差异:AR(自回归)只能看到过去token,预测未来token;MVTM可双向利用上下文(空间信息)来描绘mask位置。


 3.2 Iterative Decoding

这是MaskGIT能快64倍的关键。

传统AR解码:

  • 一次只生成1个token,要生成 N 个token就得跑 N 步 → 对高分辨率图非常慢(像32×32 token要 1024 步)。

MaskGIT解码(并行 & 多次迭代):

  • Step 0:所有token都mask,表示未知(初始画布是空的)。

  • 每次迭代:

    1. Predict:对所有“当前被mask”的位置同时预测概率分布(并行计算)。

    2. Sample:在mask位置采样token,并赋予置信度(softmax概率)。

      • 已确定位置置信度=1.0。

    3. Mask Schedule:根据mask比例γ决定下一步要mask多少个token。mask掉当前低置信度的token,其他保留。

    4. Repeat:重复直到所有token都不再mask。

  • T步完成整张图:论文实验中T=8即可搞定,比如:

    • 第1步:填1个token

    • 第2步:填更多

    • 第8步:全部token填满 → 完成。

  • 好处:

    • 每步并行预测所有token → 提速;

    • 中期可以修改错误 → 增强全局一致性;

    • 仿佛画画:先大框架,后细节润色。


3.3 Masking Design

  • mask调度函数 γ(r) 定义每个迭代的mask比例(r在0~1表示进度)。

  • 必须满足:

    1. 0 ≤ γ(r) ≤ 1;

    2. 随着 r(进度)增加,γ递减:γ(0)≈1(初始全mask),γ(1)≈0(最后无mask)。

  • 三类函数:

    1. Linear:每次减少相同数量token;

    2. Concave(凹型):先少填,后期快速增加 → 类似人类画画流程(先大概,再密集细化)。包括cosine、平方、立方、指数等;

    3. Convex(凸型):反过来,先大量确定token,后期少改动(逻辑不太符合图像细化规律)。

  • 结论(实验4.4):

    • Cosine最好(略胜平方),性能稳定、FID最低(质量最高);

    • 凹型函数普遍优于线性,线性优于凸型。


4 实验

4.2 类条件生成

(ImageNet 256×256, 512×512)

  • FID 同分辨率下直接碾压 VQGAN(6.18 vs 15.78 @ 256x256),速度快30~64倍。

  • 超越 BigGAN 在某些指标(尤其FID512x512,降到 7.32)。

  • 多样性(CAS、Recall)几乎新SOTA(在不损失precison的情况下Recall更高 → 输出更多样化)。

4.3 图像编辑能力

(无需改architecture)

  • 类条件编辑(Class-conditional editing):某个bounding box内换成另一类,外部背景保持。AR方法几乎做不到,MaskGIT天然支持。

  • Inpainting(补洞):Places2数据集下超过DeepFillv2、HiFill,接近CoModGAN。

  • Outpainting(外延): 超越Boundless、InfinityGAN等GAN方法,FID&IS更优,还能采多种seed生成多样扩展。

归根到底,这灵活性来自“mask+双向注意力”的并行预测能力。

4.4 Ablation Mask Scheduling

  • 凹型(Cosine/Square)普遍好 → 因为能训练难例,又符合“先粗后细”。

  • 存在“sweet spot”迭代步数,T太多可能反而降低多样性,T=8~12最好。


总结方法优势

  1. 速度优势: 迭代8步,速度比自回归快最高64倍。

  2. 质量优势: FID/IS/CAS全面超SOTA。

  3. 灵活编辑: 天然支持局部生成(编辑/补全/外延)。

  4. 简单训练: 无需GAN对抗,稳定。

  5. 可扩展性: 同一tokenizer+transformer可直接迁移分辨率/任务。

http://www.dtcms.com/a/390263.html

相关文章:

  • Maya绑定基础知识总结合集:父子关系和父子约束对比
  • 从假设检验到数据驱动决策:统计推断的技术实战与方法论深度拆解
  • 基于PyTorch Geometric的图神经网络预训练模型实现
  • UniTask在Webgl上报错的记录
  • 供应链场景下Oracle分库分表案例架构及核心代码分析
  • 【leetcode】59. 螺旋矩阵 II
  • Discord+ChatGPT?如何利用AI打造海外私域社群?
  • 概率论强化大观
  • 数据结构——单链表(c语言笔记)
  • 【系列文章】Linux系统中断的应用05-延迟工作
  • Cannot find module ‘@ohos/ohoszxing‘ 和安装ohoszxing 的第三方库
  • Intelligent parking
  • 【试题】数据安全管理员考试题目
  • linux中的redis
  • 工作笔记-----stm32随机数发生器RNG配置问题
  • SQL中NTILE函数的用法详解
  • Rokid乐奇成为国礼的秘密,是握住美好数字生活的定义权
  • 基于 3D 高斯泼溅的重建 (3DGS-based)
  • Gin 集成 Redis:从连接到实战
  • python-asyncio与事件循环(Event Loop)
  • 100道经典C语言笔试题(前15道)
  • MySQL Binlog 实时监控与数据插入示例
  • SQL Server运维实战指南:从监控到优化的完整路径
  • 2025年09月18日Github流行趋势
  • AI审核工程文档:某电力企业如何把返工率从20%降到0?
  • 单北斗GNSS变形监测系统是什么?主要应用于哪些领域?
  • GELU:高斯与Dropout的完美融合
  • 实习智能家居网关项目总结(MQTT与云端连接通信)
  • 简单的随机抽奖算法
  • 开源免费的天若OCR复活版,翻译功能大升级