当前位置: 首页 > news >正文

论文阅读笔记——FLOW MATCHING FOR GENERATIVE MODELING

Flow Matching 论文
扩散模型:根据中心极限定理,对原始图像不断加高斯噪声,最终将原始信号破坏为近似的标准正态分布。这其中每一步都构造为条件高斯分布,形成离散的马尔科夫链。再通过逐步去噪得到原始图像。
Flow matching 采取直接将已知分布(如白噪声)转换为真实数据分布来生成数据,并且 Flow 是基于 Normalizing Flow,故而是可微双射。生成过程中变化的概率密度构成一个集合,称为概率密度路径 p t p_t pt ,T 为路径长度。初始数据 x 0 ∼ p 0 ( x 0 ) x_0 \sim p_0(x_0) x0p0(x0),目标数据 x T ∼ p T ( x T ) x_T \sim p_T(x_T) xTpT(xT)
x 0 x_0 x0 x T x_T xT 的过程可以表示为: x T = ϕ ( x 0 ) = ϕ T ∘ ⋯ ∘ ϕ t + 1 ∘ ϕ t ∘ ⋯ ϕ 1 ( x 0 ) x_T=\phi(x_0)=\phi_T\circ\cdots\circ\phi_{t+1}\circ\phi_t\circ\cdots\phi_1(x_0) xT=ϕ(x0)=ϕTϕt+1ϕtϕ1(x0)
且对中间任意时间步 x t x_t xt 有:
x t = ϕ t ( x t − 1 ) x t − 1 = ϕ t − 1 ( x t ) \begin{aligned} x_t=\phi_t(x_{t-1}) \\x_{t-1}=\phi_t^{-1}(x_t) \end{aligned} xt=ϕt(xt1)xt1=ϕt1(xt)
根据概率密度函数的变量变换关系可得:(行列式为时刻 t 对应的流 ϕ t \phi_t ϕt 的 Jacobian 行列式)
p t ( x t ) = p t − 1 ( x t − 1 ) d e t [ ∂ x t − 1 ∂ x t ] = p t − 1 ( ϕ t − 1 ( x t ) ) d e t [ ∂ ϕ t − 1 ∂ x t ( x t ) ] ( 1 ) \begin{aligned} p_t(x_t) & =p_{t-1}(x_{t-1})\mathrm{det}\left[\frac{\partial x_{t-1}}{\partial x_t}\right] \\ & =p_{t-1}(\phi_t^{-1}(x_t))\mathrm{det}\left[\frac{\partial\phi_t^{-1}}{\partial x_t}(x_t)\right] \qquad \qquad (1) \end{aligned} pt(xt)=pt1(xt1)det[xtxt1]=pt1(ϕt1(xt))det[xtϕt1(xt)](1)
那么就可以从初始数据分布 p 0 p_0 p0 推导到目标数据分布 p T p_T pT。行列式的本质是空间缩放的度量,相当于每次变换时都对概率密度进行归一化,采用更简洁的前推方程为:
p t = [ ϕ t ] ∗ p 0 p_t=[\phi_t]_{*}p_0 pt=[ϕt]p0

向量场建模

这可以通过 Neural Ordinary Differential Equations(NODE) 对 Jacobian 行列式中的常微分方程(ODE)建模,求出 ϕ t \phi_t ϕt。为了实现这一点,需要将离散的时间步 t = { t i } i = 1 T t=\{t_i\}_{i=1}^T t={ti}i=1T 映射到连续时间变量 t ∈ [ 0 , 1 ] t\in[0,1] t[0,1] ,这样将 p t p_t pt 定义为连续时间和数据点的笛卡尔积: p : [ 0 , 1 ] × R d − > R > 0 p:[0,1] × \mathbb{R}^d -> \mathbb{R}_{>0} p:[0,1]×Rd>R>0 ∫ p t ( x ) d x = 1 \int p_t(x)dx=1 pt(x)dx=1,这就是 CNF 建模。
虽然 CNF 对 Flow 进行了建模,但同时也面临 Jacobian 行列式计算性能低下、训练时需要进行模拟、难以高效采样的问题。Flow Matching 通过回归概率路径的向量场,在训练时规避了复杂计算,并具有更高的采样性能,而且向量场基本上是完全精确的,与扩散模型使用的变分推断等近似方法相比在似然计算上更具优势。而重新审视 Flow,数据点在时间上的变换可以用 Flow 的梯度表示,构成了关于时间的向量场
d d t ϕ t ( x ) = v t ( ϕ t ( x ) ) ( 2 ) ϕ 0 ( x ) = x \begin{aligned} \frac{d}{dt}\phi_t(x) &= v_t(\phi_t(x)) \qquad \qquad (2)\\ \phi_0(x) &= x \end{aligned} dtdϕt(x)ϕ0(x)=vt(ϕt(x))(2)=x
对于给定向量场 ϕ t \phi_t ϕt 满足式(1)时,可以说向量场 v t v_t vt 生成概率密度路径 p t p_t pt,类比物理学的连续性方程,则有:
d d t p t + ∇ ⋅ v t p t = 0 ( 3 ) \frac{d}{dt}p_t+\nabla ·v_tp_t=0 \qquad \qquad (3) dtdpt+vtpt=0(3)
选取合适的 ODE 求解器,那么可以得到目标函数:
L F M ( θ ) = E t , p t ( x ) ∣ ∣ v t ( x ) − u t ( x ) ∣ ∣ 2 ( 4 ) \mathcal{L}_{FM}(\theta)=\mathbb{E}_{t,p_t(x)}||v_t(x)-u_t(x)||^2 \qquad \qquad (4) LFM(θ)=Et,pt(x)∣∣vt(x)ut(x)2(4)

条件流匹配

尽管我们已经推导出了 L F M \mathcal{L}_{FM} LFM 但是目标向量 u t u_t ut 还是未知的,无法学习,可以通过用易于访问的混合分布来构造真实概率路径。通过目标数据样本 x 1 ∼ q ( x 1 ) x_1 \sim q(x_1) x1q(x1) 定义一个条件概率 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1),使得 p 0 ( x ∣ x 1 ) = p ( x ) p_0(x|x_1)=p(x) p0(xx1)=p(x) 。用这个条件概率和真实分布 q ( x 1 ) q(x_1) q(x1) 边缘化 p t p_t pt 有:
p t ( x ) = ∫ p t ( x ∣ x 1 ) q ( x 1 ) d x 1 ( 5 ) p_t(x)=\int p_t(x|x_1)q(x_1)dx_1 \qquad \qquad (5) pt(x)=pt(xx1)q(x1)dx1(5)
同样“边缘化”向量场 u t u_t ut,有:
u ( x ) = ∫ u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) q ( x 1 ) p t ( x ) ( 6 ) u(x)=\int u_t(x|x_1)\frac{p_t(x|x_1)q(x_1)}{p_t(x)} \qquad \qquad(6) u(x)=ut(xx1)pt(x)pt(xx1)q(x1)(6)
该公式是连接条件向量场和边缘向量场的桥梁,论文证明只要条件向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1) 能生成对应的条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1),对于任何分布 q ( x 1 ) q(x_1) q(x1)上述定义的边缘向量场 u t ( x ) u_t(x) ut(x) 能够生成对应的边缘概率路径 p t ( x ) p_t(x) pt(x)。并且证明了除与 θ \theta θ 无关的常数外, L C M \mathcal{L}_{CM} LCM L F M \mathcal{L}_{FM} LFM 相等,即 ∇ θ L C M ( θ ) = ∇ θ L C F M ( θ ) \nabla_\theta \mathcal{L}_{CM}(\theta)=\nabla_\theta\mathcal{L}_{CFM}(\theta) θLCM(θ)=θLCFM(θ)
由此论文提出了基于条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1)条件向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(xx1) 的目标函数:
L C F M ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ) − u t ( x ∣ x 1 ) ∥ 2 ( 7 ) \mathcal{L}_{\mathbf{CFM}}(\theta)=\mathbb{E}_{t,q(x_1),p_t(x|x_1)}\|v_t(x)-u_t(x|x_1)\|^2 \qquad \qquad(7) LCFM(θ)=Et,q(x1),pt(xx1)vt(x)ut(xx1)2(7)
Conditional Flow Matching可以选择任意的条件概率路径,只要满足边界条件即可,这里针对一般高斯条件概率路径 p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(x|x_1)=\mathcal{N}(x|\mu_t(x_1),\sigma_t(x_1)^2I) pt(xx1)=N(xμt(x1),σt(x1)2I)

  • 当 t=0 时, μ 0 ( x 1 ) = 0 σ 0 ( x 1 ) = 1 \mu_0(x_1)=0 \qquad \sigma_0(x_1)=1 μ0(x1)=0σ0(x1)=1 确保所有的条件概率路径都会收敛到相同的标准高斯分布。
  • 当 t=1 时, μ 1 ( x 1 ) = x 1 σ 1 ( x 1 ) = σ m i n \mu_1(x_1)=x_1 \qquad \sigma_1(x_1)=\sigma_{min} μ1(x1)=x1σ1(x1)=σmin x 1 x_1 x1 为中心的高斯分布。
    对于一个概率路径,存在无限多个向量场可以生成它,这里采用高斯分布的标准变换:
    ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) ( 8 ) \psi_t(x)=\sigma_t(x_1)x+\mu_t(x_1) \qquad \qquad (8) ψt(x)=σt(x1)x+μt(x1)(8)
    代入式(2)则有:
    d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{dt}\psi_t(x)=u_t(\psi_t(x)|x_1) dtdψt(x)=ut(ψt(x)x1)
    x 0 x_0 x0 重参数化 p t ( x ∣ x 1 ) p_t(x|x_1) pt(xx1) 代入式(7)则有:
    L C F M ( θ ) = E t , q ( x 1 ) , p 0 ( x 0 ) ∥ v t ( ψ t ( x 0 ) ) − d d t ψ t ( x 0 ) ∥ 2 ( 9 ) \mathcal{L}_{\mathbf{CFM}}(\theta)=\mathbb{E}_{t,q(x_1),p_0(x_0)}\|v_t(\psi_t(x_0))-\frac{d}{dt}\psi_t(x_0)\|^2 \qquad \qquad(9) LCFM(θ)=Et,q(x1),p0(x0)vt(ψt(x0))dtdψt(x0)2(9)
    根据式(8)可知 ϕ t \phi_t ϕt 是可逆的仿射变换,故而可以得到条件向量场为:
    u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) ( 10 ) u_t(x|x_1)=\frac{\sigma_t^{'}(x_1)}{\sigma_t(x_1)}(x-\mu_t(x_1))+\mu_t^{'}(x_1) \qquad \qquad (10) ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1)(10)

值得注意的是,这里选择高斯概率密度路径只是可选方式之一,实际上可以根据需要设计任何合理的路径,SD3 也是 Conditional Flow Match 的一种应用。

扩散模型

在这里插入图片描述

Variance Exploding(VE)

根据反向时间对称性,从噪声到数据的逆向过程中的条件 p t p_t pt 为:
p t ( x ) = N ( x ∣ x 1 , σ 1 − t 2 I ) p_t(x)=\mathcal{N}(x|x_1,\sigma_{1-t}^2I) pt(x)=N(xx1,σ1t2I)
σ 0 = 0 σ 1 > > 1 \sigma_0=0 \quad \sigma_1>>1 σ0=0σ1>>1。因此 μ t ( x 1 ) = x 1 \mu_t(x_1)=x_1 μt(x1)=x1 σ t ( x 1 ) = σ 1 − t \sigma_t(x_1)=\sigma_{1-t} σt(x1)=σ1t 代入式(10):
u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) ( 11 ) u_t(x|x_1)=-\frac{\sigma_{1-t}^{'}}{\sigma_{1-t}}(x-x_1) \qquad \qquad (11) ut(xx1)=σ1tσ1t(xx1)(11)
对于 VP-SDE (DDPM)的条件 p t p_t pt 有:
p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) , w h e r e α t = e − 1 2 T ( t ) , T ( t ) = ∫ 0 t β ( s ) d s p_t(x|x_1)=\mathcal{N}(x|\alpha_{1-t}x_1,(1-\alpha_{1-t}^2)I),\quad\mathrm{where~}\alpha_t=e^{-\frac{1}{2}T(t)},T(t)=\int_0^t\beta(s)ds pt(xx1)=N(xα1tx1,(1α1t2)I),where αt=e21T(t),T(t)=0tβ(s)ds
其中 β \beta β 是噪声尺度函数。因此: μ t ( x 1 ) = α 1 − t x 1 \mu_t(x_1)=\alpha_{1-t}x_1 μt(x1)=α1tx1 σ t ( x 1 ) = 1 − α 1 − t 2 \sigma_t(x_1)=\sqrt{1-\alpha_{1-t}^2} σt(x1)=1α1t2 代入式(10)得到 u t u_t ut 解析式:
u t ( x ∣ x 1 ) = α 1 − t ′ 1 − α 1 − t 2 ( α 1 − t x − x 1 ) = − T ′ ( 1 − t ) 2 [ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) ] u_t(x|x_1)=\frac{\alpha_{1-t}^{'}}{1-\alpha_{1-t}^2}(\alpha_{1-t}x-x_1)=-\frac{T^{'}(1-t)}{2}\left[e^{-T(1-t)}x-\frac{e^{-{\frac{1}2T(1-t)}}x_1}{1-e^{-T(1-t)}}\right] ut(xx1)=1α1t2α1t(α1txx1)=2T(1t)[eT(1t)x1eT(1t)e21T(1t)x1]

Optimal Transport

由于 Flow Matching 不依赖扩散过程,可以构建一个最优传输,将条件 p t p_t pt μ \mu μ σ \sigma σ 构建为简单的随时间的线性变换:
μ t ( x ) = t x 1 σ t ( x ) = 1 − ( 1 − σ m i n ) t \begin{aligned} \mu_t(x)=tx_1 \\\sigma_t(x)=1-(1-\sigma_{min})t \end{aligned} μt(x)=tx1σt(x)=1(1σmin)t
代入式(10)得到条件向量场 u t u_t ut
u t ( x ∣ x 1 ) = x 1 − ( 1 − σ m i n ) x 1 − ( 1 − σ m i n ) t u_t(x|x_1)=\frac{x_1-(1-\sigma_{min})x}{1-(1-\sigma_{min})t} ut(xx1)=1(1σmin)tx1(1σmin)x
根据式(8)对应的条件流 ψ \psi ψ 有:
ψ t ( x ) = σ t ( x ) x + μ t ( x ) = ( 1 − ( 1 − σ m i n ) t ) x + t x 1 ( 12 ) \psi_t(x)=\sigma_t(x)x+\mu_t(x)=(1-(1-\sigma_{min})t)x+tx_1 \qquad \qquad(12) ψt(x)=σt(x)x+μt(x)=(1(1σmin)t)x+tx1(12)
根据式(2)和式(9)重参数化 ψ \psi ψ 得到目标函数为:
L C F M ( θ ) = E t , q ( x 1 ) , p 0 ( x 0 ) ∥ v t ( ψ t ( x 0 ) ) − ( x 1 − ( 1 − σ m i n ) x 0 ) ∥ 2 ( 13 ) \mathcal{L}_{\mathbf{CFM}}(\theta)=\mathbb{E}_{t,q(x_1),p_0(x_0)}\|v_t(\psi_t(x_0))-(x_1-(1-\sigma_{min})x_0)\|^2 \qquad \qquad(13) LCFM(θ)=Et,q(x1),p0(x0)vt(ψt(x0))(x1(1σmin)x0)2(13)

在这里插入图片描述

Reference

  1. https://zhuanlan.zhihu.com/p/741939590
  2. https://zhuanlan.zhihu.com/p/685921518

相关文章:

  • XUANYING炫影-移动版-智能轻云盒SY900Pro和SY910_RK3528芯片_免拆机通刷固件包
  • 在大型中实施访问控制 语言模型
  • BERT***
  • docker环境添加安装包持久性更新
  • Warm-Flow发布1.7.3 端午节(设计器流和流程图大升级)
  • Unity UI系统中RectTransform详解
  • C#面试问题41-60
  • gitLab 切换中文模式
  • 基于 HT for Web 的轻量化 3D 数字孪生数据中心解决方案
  • superior哥深度学习系列(大纲)
  • gbase8s数据库+mybatis问题记录
  • 004 flutter基础 初始文件讲解(3)
  • 【Vim】高效编辑技巧全解析
  • Flutter 4.x 版本 webview_flutter 嵌套H5
  • 【计算机网络】应用层协议Http——构建Http服务服务器
  • Flutter 嵌套H5 传参数
  • 芯片:数字时代的算力引擎——鲲鹏、升腾、海光、Intel 全景解析
  • 快捷键IDEA
  • [网页五子棋][匹配模式]创建房间类、房间管理器、验证匹配功能,匹配模式小结
  • Python打卡训练营Day40
  • 网站后台开发/2023国内外重大新闻事件10条
  • 政府门户网站集约化建设方案/目前在哪个平台做推广好
  • 网络水果有哪些网站可以做/百度公司注册地址在哪里
  • 武汉市网站开发公司/seo关键词排名优化
  • 查看域名之前做的网站/谷歌手机网页版入口
  • 78建筑网站/网络广告是什么