当前位置: 首页 > news >正文

Flow Matching Guide and Code(3)

Flow Matching Guide and Code(3)

文章目录

  • Flow Matching Guide and Code(3)
    • 4. Flow Matching
      • 4.1 Data
      • 4.2 Building probability paths
      • 4.3 Deriving generating velocity fields
      • 4.4 General conditioning and the Marginalization Trick
      • 4.5 Flow Matching loss
      • 4.6 Solving conditional generation with conditional flows

Code: flow_matching library at https://github.com/facebookresearch/flow_matching

4. Flow Matching

给定一个源分布 ppp 和一个目标分布 qqq流匹配(Flow Matching, FM) (Lipman et al., 2022; Liu et al., 2022; Albergo and Vanden-Eijnden, 2022) 是一种可扩展的方法,用于训练一个由可学习的速度场 utθu_t^\thetautθ 定义的流模型,并解决以下流匹配问题
Findutθgenerating pt,withp0=pandp1=q.(4.1)\mathrm{Find~}u_t^\theta\text{ generating }p_t,\mathrm{~with~}p_0=p\mathrm{~and~}p_1=q. \tag{4.1}Find utθ generating pt, with p0=p and p1=q.(4.1)

上式中的“生成”是依据方程 (3.24) 的含义。回顾图2中的流匹配蓝图,FM框架 (a) 确定一个已知的源分布 ppp 和一个未知的数据目标分布 qqq;(b) 规定一条从 p0=pp_0 = pp0=p 插值到 p1=qp_1 = qp1=q概率路径 ptp_tpt;© 学习一个用神经网络实现的、生成路径 ptp_tpt速度场 utθu_t^\thetautθ;(d) 通过求解带有 utθu_t^\thetautθ 的ODE来从学习到的模型中采样

为了在步骤 © 中学习速度场 utθu_t^\thetautθ,FM通过最小化以下回归损失来实现:

LFM(θ)=EXt∼ptD(ut(Xt),utθ(Xt)),(4.2)\mathcal{L}_{\mathrm{FM}}(\theta)=\mathbb{E}_{X_{t}\sim p_{t}}D\left(u_{t}(X_{t}),u_{t}^{\theta}(X_{t})\right), \tag{4.2}LFM(θ)=EXtptD(ut(Xt),utθ(Xt)),(4.2)

其中 DDD 是向量间的差异度量,例如平方 l2\mathcal{l_2}l2 范数 D(u,v)=∥u−v∥2D(u,v)=\left\|u-v\right\|^2D(u,v)=uv2 。直观地说,FM损失函数鼓励我们可学习的速度场 utθu_t^\thetautθ 去匹配已知能生成所需概率路径 ptp_tpt 的真实速度场 utu_tut。图9描绘了流匹配框架中的主要对象及它们之间的依赖关系。让我们通过描述如何构建 ptp_tptutu_tut,以及损失函数 (4.2) 的一个实际实现,来开始我们对流匹配的阐述。

image-20250911124356380

Figure 9 流匹配(Flow Matching)框架的主要对象及其相互关系。一个流(Flow)速度场(Velocity field) 表示,该速度场定义了一个生成 概率路径(Probability path) 的随机过程。流匹配的核心思想是将构建一个满足期望边界条件(Boundary conditions) 的复杂流(顶行)这一问题,分解(break down) 为构建多个满足更简单边界条件的条件流(conditional flows)(中间行)的问题,后者因此更容易解决。箭头表示了不同对象之间的依赖关系(dependencies)蓝色箭头标示了流匹配框架所采用的(employed)核心关系。“损失(Loss)” 列列出了用于学习速度场的损失函数,其中 CFM 损失(条件流匹配损失)(中间行和底行)是在实践中实际使用的损失。底行列出了在第2节中描述的最简单的流匹配算法实例(instantiation)。

4.1 Data

重申一下,令源样本为随机变量 X0∼pX_0 ∼ pX0p目标样本为随机变量 X1∼qX_1 ∼ qX1q。通常,源样本遵循一个已知的、易于采样的分布(如高斯分布),而目标样本则以一个有限大小的数据集的形式提供给我们。根据应用的不同,目标样本可以构成图像、视频、音频片段或其他类型的高维、富含结构的数据。源样本和目标样本可以是独立的,或者源自一个称为耦合(coupling) 的通用联合分布:

(X0,X1)∼π0,1(X0,X1),(4.3)(X_0,X_1)\sim\pi_{0,1}(X_0,X_1), \tag{4.3}(X0,X1)π0,1(X0,X1),(4.3)
其中,如果没有已知的耦合,则源-目标样本遵循独立耦合(independent coupling) π0,1(X0,X1)=p(X0)q(X1)\pi_{0,1}(X_0,X_1)=p(X_0)q(X_1)π0,1(X0,X1)=p(X0)q(X1)独立源-目标分布的一个常见例子是:考虑从随机高斯噪声向量 X0∼N(0,I)X_0 \sim \mathcal{N}(0, I)X0N(0,I) 生成图像 X1X_1X1。作为依赖耦合(dependent coupling) 的一个例子,可以考虑从其低分辨率版本 X0X_0X0 生成高分辨率图像 X1X_1X1,或从其灰度版本 X0X_0X0 生成彩色视频 X1X_1X1 的情况。

4.2 Building probability paths

流匹配(Flow Matching)通过采用一种条件策略(conditional strategy),极大地简化了设计概率路径 ptp_tpt 及其对应速度场 utu_tut 的问题。作为第一个例子,考虑将 ptp_tpt 的设计条件于一个单独的目标样本 X1=x1X_1 = x_1X1=x1,这就产生了如图3a所示的条件概率路径 pt∣1(x∣x1)p_{t|1}(x|x_1)pt∣1(xx1)。然后,我们可以通过聚合(aggregating) 这些条件概率路径 pt∣1p_{t|1}pt∣1 来构建整体的、边际的概率路径 ptp_tpt

pt(x)=∫pt∣1(x∣x1)q(x1)dx1(4.4)p_t(x) = \int p_{t|1}(x|x_1) q(x_1) dx_1 \tag{4.4}pt(x)=pt∣1(xx1)q(x1)dx1(4.4)

如图3b所示。为了解决流匹配问题,我们希望 ptp_tpt 满足以下边界条件(boundary conditions)
p0=p,p1=q,(4.5)p_0 = p, \quad p_1 = q, \tag{4.5}p0=p,p1=q,(4.5)
也就是说,边际概率路径 ptp_tpt 从时间 t=0t = 0t=0 的源分布 ppp 插值(interpolates) 到时间 t=1t = 1t=1 的目标分布 qqq。这些边界条件可以通过要求条件概率路径满足以下条件来强制执行:
p0∣1(x∣x1)=π0∣1(x∣x1),and p1∣1(x∣x1)=δx1(x),(4.6)p_{0|1}(x|x_1) = \pi_{0|1}(x|x_1), \ \ \text{and} \ p_{1|1}(x|x_1) = \delta_{x_1}(x), \tag{4.6}p0∣1(xx1)=π0∣1(xx1),  and p1∣1(xx1)=δx1(x),(4.6)
其中条件耦合(conditional coupling)$ \pi_{0|1} (x_0 | x_1) = \pi_{0,1} (x_0,x_1) / q(x_1) $,而 δx1δ_{x₁}δx1集中于 x1x_1x1 的狄拉克测度(delta measure)。对于独立耦合(independent coupling) π0,1(x0,x1)=p(x0)q(x1)\pi_{0,1}(x_0, x_1) = p(x_0)q(x_1)π0,1(x0,x1)=p(x0)q(x1),上述第一个约束简化为 p0∣1(x∣x1)=p(x)p_{0|1}(x|x_1) = p(x)p0∣1(xx1)=p(x)。由于狄拉克测度没有密度函数,第二个约束应理解为:对于连续函数 fff,当 t→1t \to 1t1 时,有 ∫pt∣1(x∣y)f(y)dy→f(x)\int p_{t|1}(x|y)f(y)\mathrm{d}y \to f(x)pt∣1(xy)f(y)dyf(x)。请注意,将 (4.6) 代入 (4.4) 即可验证边界条件 (4.5)。

一个满足 (4.6) 中条件的流行条件概率路径的例子在 (2.2) 式中给出:

N(⋅∣tx1,(1−t)2I)→δx1(⋅)as t→1.\mathcal{N}(\cdot|t x_1, (1-t)^2 I) \to \delta_{x_1}(\cdot) \text{ as } t \to 1.N(tx1,(1t)2I)δx1() as t1.

(即:一个均值向 tx1tx_1tx1 移动、方差逐渐缩小的正态分布,在 t=1t=1t=1 时坍缩到点 x1x_1x1 上的狄拉克分布。)

4.3 Deriving generating velocity fields

在具备了边际概率路径 ptp_tpt 之后,我们现在来构建一个生成 ptp_tpt 的速度场 utu_tut。这个生成性的边际速度场 utu_tut 是多个条件速度场 ut(x∣x1)u_t(x|x_1)ut(xx1)(图示见图3c)的平均,其中每个条件速度场满足:

ut(⋅∣x1)generates pt∣1(⋅∣x1).(4.7)u_t(\cdot |x_1) \text{ generates } p_{t|1}(\cdot |x_1). \tag{4.7}ut(x1) generates pt∣1(x1).(4.7)
那么,生成边际路径 pt(x)p_t(x)pt(x) 的边际速度场 ut(x)u_t(x)ut(x)(图示见图3d)由以下公式给出,即在所有目标样本上对条件速度场 ut(x∣x1)u_t(x|x_1)ut(xx1) 进行平均

ut(x)=∫ut(x∣x1)p1∣t(x1∣x)dx1.(4.8)u_t(x) = \int u_t(x|x_1) p_{1|t}(x_1|x) \mathrm{d} x_1. \tag{4.8}ut(x)=ut(xx1)p1∣t(x1x)dx1.(4.8)

为了用已知项来表达上述方程,我们回顾贝叶斯规则

p1∣t(x1∣x)=pt∣1(x∣x1)q(x1)pt(x),(4.9)p_{1|t}(x_1|x)=\frac{p_{t|1}(x|x_1)q(x_1)}{p_t(x)}, \tag{4.9}p1∣t(x1x)=pt(x)pt∣1(xx1)q(x1),(4.9)

该式对所有满足 pt(x)>0p_t(x) > 0pt(x)>0xxx 有定义。方程 (4.8) 可以解释为条件速度 ut(x∣x1)u_t(x|x_1)ut(xx1) 的加权平均,其权重 p1∣t(x1∣x)p_{1|t}(x_1|x)p1∣t(x1x) 代表了在给定当前样本 xxx 的情况下,目标样本 x1x_1x1 的后验概率。方程 (4.8) 的另一种解释可以通过条件期望(见第3.2节)给出。即,如果存在一个随机变量 XtX_tXt,使得 Xt∼pt∣1(⋅∣X1)X_t \sim p_{t|1}(\cdot|X_1)Xtpt∣1(X1),或者等价地,(Xt,X1)(X_t, X_1)(Xt,X1) 的联合分布具有密度 pt,1(x,x1)=pt∣1(x∣x1)q(x1)p_{t,1}(x, x_1) = p_{t|1}(x|x_1) q(x_1)pt,1(x,x1)=pt∣1(xx1)q(x1),那么使用 (3.12) 将 (4.8) 写成一个条件期望,我们得到:

ut(x)=E[ut(Xt∣X1)∣Xt=x],(4.10)u_t(x)=\mathbb{E}\left[u_t(X_t|X_1)|X_t=x\right], \tag{4.10}ut(x)=E[ut(XtX1)Xt=x],(4.10)

这得到了一个对 ut(x)u_t(x)ut(x) 有用的解释:它是 ut(Xt∣X1)u_t(X_t | X_1)ut(XtX1) 在给定 Xt=xX_t = xXt=x 时的最小二乘近似(least-squares approximation)(见第3.2节)。请注意,(4.10) 中的 XtX_tXt 通常是一个与由最终流模型 (3.16) 定义的 XtX_tXt 不同的随机变量,尽管它们共享相同的边际概率 pt(x)p_t(x)pt(x)

4.4 General conditioning and the Marginalization Trick

为了证明上述构造的合理性,我们需要证明在温和的假设下,由方程 (4.8) 和 (4.10) 定义的边际速度场 utu_tut 生成了由方程 (4.4) 定义的边际概率路径 ptp_tpt。证明所需的数学工具是质量守恒定理(定理 2)。为此,让我们考虑一个稍后在本手册中会有用的、更一般的设定。特别地,通过条件于 X1=x1X_1 = x_1X1=x1 来构建条件概率路径和速度场并没有什么特殊之处。正如 Tong et al. (2023) 所指出的,前一节的分析可以推广到条件于任意随机变量 Z∈RmZ \in \mathbb{R}^mZRm(其概率密度函数为 pZp_ZpZ 的情况。这就产生了边际概率路径
pt(x)=∫pt∣Z(x∣z)pZ(z)dz,(4.11)p_t(x)=\int p_{t|Z}(x|z)p_Z(z)dz, \tag{4.11}pt(x)=ptZ(xz)pZ(z)dz,(4.11)
而该路径又由边际速度场生成:

ut(x)=∫ut(x∣z)pZ∣t(z∣x)dz=E[ut(Xt∣Z)∣Xt=x],(4.12)u_t(x)=\int u_t(x|z)p_{Z|t}(z|x)dz=\mathbb{E}\left[u_t(X_t|Z)\mid X_t=x\right], \tag{4.12}ut(x)=ut(xz)pZt(zx)dz=E[ut(XtZ)Xt=x],(4.12)

其中:

  • ut(⋅∣z)u_t(\cdot|z)ut(z) 生成 pt∣Z(⋅∣z)p_{t|Z}(\cdot|z)ptZ(z),
  • pZ∣t(z∣x)=pt∣Z(x∣z)pZ(z)pt(x)p_{Z|t}(z|x) = \frac{p_{t|Z}(x|z) p_Z(z)}{p_t(x)}pZt(zx)=pt(x)ptZ(xz)pZ(z) (由贝叶斯规则得出,给定 pt(x)>0p_t(x) > 0pt(x)>0),
  • Xt∼pt∣Z(⋅∣Z)X_t \sim p_{t|Z}(\cdot|Z)XtptZ(Z)

自然,通过设置 Z=X1Z = X_1Z=X1,我们可以恢复前面章节中的构造。

在证明主要结果之前,我们需要一些正则性假设,总结如下:

Assumption 1
pt∣Z(x∣z)p_{t|Z}(x|z)ptZ(xz)(t,x)(t, x)(t,x)C1([0,1)×Rd)C^1([0, 1) \times \mathbb{R}^d)C1([0,1)×Rd) 函数,且 ut(x∣z)u_t(x|z)ut(xz)(t,x)(t, x)(t,x)C1([0,1)×Rd,Rd)C^1([0, 1) \times \mathbb{R}^d, \mathbb{R}^d)C1([0,1)×Rd,Rd) 函数。此外,pZp_ZpZ 具有有界支撑,即 pZ(x)=0p_Z(x) = 0pZ(x)=0Rm\mathbb{R}^mRm 的某个有界集之外成立。最后,对于所有 x∈Rdx \in \mathbb{R}^dxRdt∈[0,1)t \in [0, 1)t[0,1),有 pt(x)>0p_t(x) > 0pt(x)>0

这些是温和的假设。例如,通过找到一个满足 pZ(z)>0p_Z(z) > 0pZ(z)>0pt∣Z(⋅∣z)>0p_{t|Z}(\cdot|z) > 0ptZ(z)>0 的条件 zzz,可以证明 pt(x)>0p_t(x) > 0pt(x)>0。在实践中,可以通过考虑 (1−(1−t)ε)pt∣Z+(1−t)εN(0,I)(1 - (1-t)\varepsilon) p_{t|Z} + (1-t)\varepsilon \mathcal{N}(0, I)(1(1t)ε)ptZ+(1t)εN(0,I)(其中 ε>0\varepsilon > 0ε>0 是任意小的数)来满足此条件。满足此假设的 pt∣Z(⋅∣z)p_{t|Z}(\cdot|z)ptZ(z) 的一个例子是 (2.2) 式中的路径,其中我们令 Z=X1Z = X_1Z=X1

我们现在准备陈述主要结果:

Theorem 3 (边际化技巧, Marginalization Trick)
在假设 1 下,如果 ut(x∣z)u_t(x|z)ut(xz)条件可积的且生成条件概率路径 pt(⋅∣z)p_t(\cdot|z)pt(z),那么边际速度场 utu_tut 生成边际概率路径 ptp_tpt,对于所有 t∈[0,1)t \in [0, 1)t[0,1) 成立。

在上述定理中,“条件可积”指的是质量守恒定理(3.26)中可积性条件的条件版本,即:
∫01∫∫∥ut(x∣z)∥pt∣Z(x∣z)pZ(x)dzdxdt<∞.(4.13)\int_0^1\int\int\|u_t(x|z)\|p_{t|Z}(x|z)p_Z(x)dz\mathrm{d}x\mathrm{d}t<\infty. \tag{4.13}01∫∫ut(xz)ptZ(xz)pZ(x)dzdxdt<∞.(4.13)

证明
该结果通过验证质量守恒定理(定理 2)的两个条件得出。首先,让我们验证组合 (ut,pt)(u_t, p_t)(ut,pt) 满足连续性方程 (3.25)。因为 ut(⋅∣z)u_t(\cdot|z)ut(z) 生成 pt(⋅∣z)p_t(\cdot|z)pt(z),我们有:

ddtpt(x)=(i)∫ddtpt∣Z(x∣z)pZ(x)dz=(ii)−∫divx[ut(x∣z)pt∣Z(x∣z)]pZ(z)dz=(i)−divx∫ut(x∣z)pt∣Z(x∣z)pZ(z)dz=(iii)−divx[ut(x)pt(x)].\begin{align}\frac{\mathrm{d}}{\mathrm{d}t}p_{t}(x)&\overset{(i)}{=}\int\frac{\mathrm{d}}{\mathrm{d}t}p_{t|Z}(x|z)p_{Z}(x)dz \tag{4.14}\\&\overset{(ii)}{=}-\int\mathrm{div}_{x}\left[u_{t}(x|z)p_{t|Z}(x|z)\right]p_{Z}(z)dz\tag{4.15}\\&\overset{(i)}{=}-\mathrm{div}_x\int u_t(x|z)p_{t|Z}(x|z)p_Z(z)dz\tag{4.16}\\&\overset{(iii)}{=}-\mathrm{div}_x\left[u_t(x)p_t(x)\right].\tag{4.17}\end{align}dtdpt(x)=(i)dtdptZ(xz)pZ(x)dz=(ii)divx[ut(xz)ptZ(xz)]pZ(z)dz=(i)divxut(xz)ptZ(xz)pZ(z)dz=(iii)divx[ut(x)pt(x)].(4.14)(4.15)(4.16)(4.17)

等号 (i)(i)(i) 成立是根据莱布尼茨法则,交换了微分(分别是 ddt\frac{d}{dt}dtddivx\text{div}_xdivx)与积分的顺序,这由以下事实证明是合理的:pt∣Z(x∣z)p_{t|Z}(x|z)ptZ(xz)ut(x∣z)u_t(x|z)ut(xz) 关于 t,xt, xt,xC1C^1C1 的,并且 pZp_ZpZ 具有有界支撑(因此所有被积函数作为有界集上的连续函数都是可积的)。等号 (ii)(ii)(ii) 成立是因为 ut(⋅∣z)u_t(\cdot|z)ut(z) 生成 pt∣Z(⋅∣z)p_{t|Z}(\cdot|z)ptZ(z) 以及定理 2。等号 (iii)(iii)(iii) 成立是通过乘以并除以 pt(x)p_t(x)pt(x)(由假设严格大于零)并使用了 utu_tut 的公式 (4.12)。

为了验证定理 2 的第二个也是最后一个条件,我们应证明 utu_tut 是可积且局部利普希茨连续的。因为 C1C^1C1 函数是局部利普希茨连续的,所以只需检查 ut(x)u_t(x)ut(x) 对所有 (t,x)(t, x)(t,x)C1C^1C1 的。这将由 ut(x∣z)u_t(x|z)ut(xz)pt∣Z(x∣z)p_{t|Z}(x|z)ptZ(xz)C1C^1C1pt(x)>0p_t(x) > 0pt(x)>0(由假设成立)来保证。此外,ut(x)u_t(x)ut(x) 是可积的,因为 ut(x∣z)u_t(x|z)ut(xz) 是条件可积的:

∫01∫∥ut(x)∥pt(x)dxdt≤∫01∫∫∥ut(x∣z)∥pt∣Z(x∣z)pZ(z)dzdxdt<∞,(4.18)\int_0^1\int\|u_t(x)\|p_t(x)\mathrm{d}x\mathrm{d}t\leq\int_0^1\int\int\|u_t(x|z)\|p_{t|Z}(x|z)p_Z(z)dz\mathrm{d}x\mathrm{d}t<\infty, \tag{4.18}01ut(x)pt(x)dxdt01∫∫ut(xz)ptZ(xz)pZ(z)dzdxdt<,(4.18)

其中第一个不等式源于向量詹森不等式 (vector Jensen’s inequality)

4.5 Flow Matching loss

在确立了目标速度场 utu_tut 生成从 pppqqq 的指定概率路径 ptp_tpt 之后,缺失的部分是一个易于处理的损失函数,用于学习一个尽可能接近目标 utu_tut 的速度场模型 utθu_t^\thetautθ

直接使用此损失函数的一个主要障碍是计算目标 utu_tut 是不可行的,因为它需要对整个训练集进行边际化(即,对方程 (4.8) 中的 x1x_1x1 或方程 (4.12) 中的 zzz 进行积分)。

幸运的是,有一族称为布雷格曼散度 (Bregman divergences) 的损失函数,可以仅根据条件速度 ut(x∣z)u_t(x|z)ut(xz) 来为学习 utθ(x)u_t^\theta(x)utθ(x) 提供无偏梯度

image-20250911183029935

Figure 10 布雷格曼散度

布雷格曼散度测量两个向量 u,v∈Rdu, v \in \mathbb{R}^du,vRd 之间的差异性,定义为:

D(u,v):=Φ(u)−[Φ(v)+⟨u−v,∇Φ(v)⟩],(4.19)D(u,v):=\Phi(u)-\left[\Phi(v)+\langle u-v,\nabla\Phi(v)\rangle\right],\tag{4.19}D(u,v):=Φ(u)[Φ(v)+uv,∇Φ(v)⟩],(4.19)

其中 Φ:Rd→R\Phi: \mathbb{R}^d \to \mathbb{R}Φ:RdR 是一个在某个凸集 Ω⊂Rd\Omega \subset \mathbb{R}^dΩRd 上定义的严格凸函数。如图10所示,布雷格曼散度测量了 Φ(u)\Phi(u)Φ(u) 与在 vvv 处展开并 evaluated at uuuΦ\PhiΦ 的线性近似之间的差异。因为线性近似是凸函数的全局下界,所以有 D(u,v)≥0D(u, v) \ge 0D(u,v)0。此外,由于 Φ\PhiΦ 是严格凸的,当且仅当 u=vu = vu=vD(u,v)=0D(u, v) = 0D(u,v)=0

最基础的布雷格曼散度是平方欧几里得距离 D(u,v)=∣u−v∣2D(u, v) = |u - v|^2D(u,v)=uv2,这源于选择 Φ(u)=∣u∣2\Phi(u) = |u|^2Φ(u)=u2

使得布雷格曼散度对 Flow Matching 有用的关键性质是它们关于第二个参数的梯度是仿射不变的 (affine invariant) (Holderrieth et al., 2024):

∇vD(au1+bu2,v)=a∇vD(u1,v)+b∇vD(u2,v),for any a+b=1,(4.20)\nabla_v D(a u_1 + b u_2, v) = a \nabla_v D(u_1, v) + b \nabla_v D(u_2, v), \text{ for any } a + b = 1, \tag{4.20}vD(au1+bu2,v)=avD(u1,v)+bvD(u2,v), for any a+b=1,(4.20)

这可以从方程 (4.19) 验证。仿射不变性允许我们如下交换期望和梯度:

∇vD(E[Y],v)=E[∇vD(Y,v)]for any RVY∈Rd.(4.21)\nabla_vD(\mathbb{E}[Y],v)=\mathbb{E}[\nabla_vD(Y,v)]\quad\text{for any RV}\quad {Y}\in\mathbb{R}^d. \tag{4.21}vD(E[Y],v)=E[vD(Y,v)]for any RVYRd.(4.21)

Flow Matching 损失采用一个布雷格曼散度,将我们的可学习速度 utθ(x)u_t^\theta(x)utθ(x) 回归 (regress) 到沿着概率路径 ptp_tpt 的目标速度 ut(x)u_t(x)ut(x) 上:

LFM(θ)=Et,Xt∼ptD(ut(Xt),utθ(Xt)),(4.22)\mathcal{L}_{\mathrm{FM}}(\theta)=\mathbb{E}_{t,X_t\sim p_t}D(u_t(X_t),u_t^\theta(X_t)), \tag{4.22}LFM(θ)=Et,XtptD(ut(Xt),utθ(Xt)),(4.22)

其中时间 t∼U[0,1]t \sim \mathcal{U}[0, 1]tU[0,1]。然而,如上所述,目标速度 utu_tut 是难以处理的,因此上面的损失无法按原样计算。相反,我们考虑更简单且易于处理的条件流匹配 (Conditional Flow Matching, CFM) 损失

LCFM(θ)=Et,Z,Xt∼pt∣Z(⋅∣Z)D(ut(Xt∣Z),utθ(Xt)).(4.23)\mathcal{L}_{\mathrm{CFM}}(\theta)=\mathbb{E}_{t,Z,X_t\sim p_{t|Z}(\cdot|Z)}D(u_t(X_t|Z),u_t^\theta(X_t)).\tag{4.23}LCFM(θ)=Et,Z,XtptZ(Z)D(ut(XtZ),utθ(Xt)).(4.23)

这两个损失对于学习目的是等价的,因为它们的梯度重合 (Holderrieth et al., 2024):

定理 4. Flow Matching 损失和 Conditional Flow Matching 损失的梯度重合:
∇θLFM(θ)=∇θLCFM(θ).(4.24)\nabla_{\theta}\mathcal{L}_{FM}(\theta)=\nabla_{\theta}\mathcal{L}_{CFM}(\theta).\tag{4.24}θLFM(θ)=θLCFM(θ).(4.24)
特别地,Conditional Flow Matching 损失的最小化器就是边际速度场 ut(x)u_t(x)ut(x)

证明. 证明遵循直接计算:

∇θLFM(θ)=∇θEt,Xt∼ptD(ut(Xt),utθ(Xt))=Et,Xt∼pt∇θD(ut(Xt),utθ(Xt))=(i)Et,Xt∼pt∇vD(ut(Xt),utθ(Xt))∇θutθ(Xt)=(4.12)Et,Xt∼pt∇vD(EZ∼pZ∣t(⋅∣Xt)[ut(Xt∣Z)],utθ(Xt))∇θutθ(Xt)=(ii)Et,Xt∼ptEZ∼pZ∣t(⋅∣Xt)[∇vD(ut(Xt∣Z),utθ(Xt))∇θutθ(Xt)]=⁡(iii)Et,Xt∼ptEZ∼pZ∣t(⋅∣Xt)[∇θD(ut(Xt∣Z),utθ(Xt))]=⁡(iv)∇θEt,Z∼q,Xt∼pt∣Z(⋅∣Z)[D(ut(Xt∣Z),utθ(Xt))]=∇θLCFM(θ)\begin{aligned}\nabla_{\theta}\mathcal{L}_{\mathrm{FM}}(\theta)&=\nabla_{\theta}\mathbb{E}_{t,X_{t}\sim p_{t}}D(u_{t}(X_{t}),u_{t}^{\theta}(X_{t}))\\&=\mathbb{E}_{t,X_{t}\sim p_{t}}\nabla_{\theta}D(u_{t}(X_{t}),u_{t}^{\theta}(X_{t}))\\&\overset{(i)}{=}\mathbb{E}_{t,X_{t}\sim p_{t}}\nabla_{v}D(u_{t}(X_{t}),u_{t}^{\theta}(X_{t}))\nabla_{\theta}u_{t}^{\theta}(X_{t})\\&\overset{(4.12)}{=}\mathbb{E}_{t,X_t\thicksim p_t}\nabla_vD(\mathbb{E}_{Z\sim p_{Z|t}(\cdot|X_t)}[u_t(X_t|Z)],u_t^\theta(X_t))\nabla_\theta u_t^\theta(X_t)\\&\overset{(ii)}{=}\mathbb{E}_{t,X_t\sim p_t}\mathbb{E}_{Z\sim p_{Z|t}(\cdot|X_t)}\left[\nabla_vD(u_t(X_t|Z),u_t^\theta(X_t))\nabla_\theta u_t^\theta(X_t)\right]\\&\overset{(iii)}{\operatorname*{=}}\mathbb{E}_{t,X_{t}\sim p_{t}}\mathbb{E}_{Z\sim p_{Z|t}(\cdot|X_{t})}[\nabla_{\theta}D(u_{t}(X_{t}|Z),u_{t}^{\theta}(X_{t}))]\\&\overset{(iv)}{\operatorname*{=}}\nabla_\theta\mathbb{E}_{t,Z\sim q,X_t\sim p_{t|Z}(\cdot|Z)}[D(u_t(X_t|Z),u_t^\theta(X_t))]\\&=\nabla_{\theta}\mathcal{L}_{\mathrm{CFM}}(\theta)\end{aligned}θLFM(θ)=θEt,XtptD(ut(Xt),utθ(Xt))=Et,XtptθD(ut(Xt),utθ(Xt))=(i)Et,XtptvD(ut(Xt),utθ(Xt))θutθ(Xt)=(4.12)Et,XtptvD(EZpZt(Xt)[ut(XtZ)],utθ(Xt))θutθ(Xt)=(ii)Et,XtptEZpZt(Xt)[vD(ut(XtZ),utθ(Xt))θutθ(Xt)]=(iii)Et,XtptEZpZt(Xt)[θD(ut(XtZ),utθ(Xt))]=(iv)θEt,Zq,XtptZ(Z)[D(ut(XtZ),utθ(Xt))]=θLCFM(θ)

其中在 (i)(i)(i)(iii)(iii)(iii) 中我们使用了链式法则;(ii)(ii)(ii) 源于将方程 (4.21) 条件于 XtX_tXt 应用;在 (iv)(iv)(iv) 中我们使用了贝叶斯规则。

用于学习条件期望的布雷格曼散度. 定理 4 是利用布雷格曼散度学习条件期望的一个更一般结果的特定实例,如下所述。它将在本手册中全程使用,并为 Flow Matching 背后所有可扩展的损失提供基础:

Proposition(命题) 1 (用于学习条件期望的布雷格曼散度). 设 X∈SXX \in S_XXSX, Y∈SYY \in S_YYSY 是状态空间 SXS_XSX, SYS_YSY 上的随机变量,且 g:Rp×SX→Rng: \mathbb{R}^p \times S_X \to \mathbb{R}^ng:Rp×SXRn, (θ,x)↦gθ(x)(\theta, x) \mapsto g_\theta(x)(θ,x)gθ(x),其中 θ∈Rp\theta \in \mathbb{R}^pθRp 表示可学习参数。设 Dx(u,v)D_x(u, v)Dx(u,v), x∈SXx \in S_XxSX 是一个在凸集 Ω⊂Rn\Omega \subset \mathbb{R}^nΩRn 上的布雷格曼散度,该凸集包含 fff 的像。那么,∇θEX,YDX(Y,gθ(X))=∇θEXDX(E[Y∣X],gθ(X)).(4.25)\nabla_\theta\mathbb{E}_{X,Y}D_X\left(Y,g^\theta(X)\right)=\nabla_\theta\mathbb{E}_XD_X\left(\mathbb{E}\left[Y\mid X\right],g^\theta(X)\right).\tag{4.25}θEX,YDX(Y,gθ(X))=θEXDX(E[YX],gθ(X)).(4.25)

特别地,对于所有满足 pX(x)>0p_X(x) > 0pX(x)>0xxxgθ(x)g_\theta(x)gθ(x) 关于 θ\thetaθ全局最小值满足

gθ(x)=E[Y∣X=x].(4.26)g^\theta(x)=\mathbb{E}\left[Y\mid X=x\right].\tag{4.26}gθ(x)=E[YX=x].(4.26)

证明. 我们假设 gθg_\thetagθ 关于 θ\thetaθ 可微,并且 XXXYYY 的分布,以及 DxD_xDxggg 允许交换微分和积分,展开:

∇θEX,YDX(Y,gθ(X))=(i)EX[E[∇vDX(Y,gθ(X))∇θgθ(X)∣X]]=⁡(ii)EX[∇vDX(E[Y∣X],gθ(X))∇θgθ(X)]=⁡(iii)EX[∇θDX(E[Y∣X],gθ(X))]=∇θEXDX(E[Y∣X],gθ(X)),\begin{aligned}\nabla_{\theta}\mathbb{E}_{X,Y}D_{X}\left(Y,g^{\theta}(X)\right)&\overset{(i)}{=}\mathbb{E}_X\left[\mathbb{E}\left[\nabla_vD_X\left(Y,g^\theta(X)\right)\nabla_\theta g^\theta(X)\mid X\right]\right]\\&\overset{(ii)}{\operatorname*{=}}\mathbb{E}_{X}\left[\nabla_{v}D_{X}\left(\mathbb{E}\left[Y\mid X\right],g^{\theta}(X)\right)\nabla_{\theta}g^{\theta}(X)\right]\\&\overset{(iii)}{\operatorname*{=}}\mathbb{E}_{X}\left[\nabla_{\theta}D_{X}\left(\mathbb{E}\left[Y\mid X\right],g^{\theta}(X)\right)\right]\\&=\nabla_{\theta}\mathbb{E}_{X}D_{X}\left(\mathbb{E}\left[Y\mid X\right],g^{\theta}(X)\right),\end{aligned}θEX,YDX(Y,gθ(X))=(i)EX[E[vDX(Y,gθ(X))θgθ(X)X]]=(ii)EX[vDX(E[YX],gθ(X))θgθ(X)]=(iii)EX[θDX(E[YX],gθ(X))]=θEXDX(E[YX],gθ(X)),

其中 (i)(i)(i) 遵循链式法则和期望的塔性质 (3.11)。等号 (ii)(ii)(ii) 遵循 (4.21)。等号 (iii)(iii)(iii) 再次使用了链式法则。最后,对于每个满足 pX(x)>0p_X(x) > 0pX(x)>0x∈SXx \in S_XxSX,我们可以选择 gθ(x)=E[Y∣X=x]g_\theta(x) = \mathbb{E} [Y | X = x]gθ(x)=E[YX=x],得到 E∗X[DX(E[Y∣X],g∗θ(X))]=0\mathbb{E}*{X} \left[ D_X( \mathbb{E} [Y | X], g*\theta(X) ) \right] = 0EX[DX(E[YX],gθ(X))]=0,这必然是关于 θ\thetaθ 的全局最小值。

通过选择 X=XtX = X_tX=Xt, Y=ut(Xt∣Z)Y = u_t(X_t|Z)Y=ut(XtZ), gθ(x)=utθ(x)g_\theta(x) = u_t^\theta(x)gθ(x)=utθ(x),并对 t∼U[0,1]t \sim \mathcal{U}[0, 1]tU[0,1] 取期望,可以立即从命题 1 证明定理 4。

一般时间分布. FM 损失的一个有用变体是从均匀分布以外的分布中采样时间 ttt。具体来说,考虑 t∼ω(t)t \sim \omega(t)tω(t),其中 ω\omegaω[0,1][0, 1][0,1] 上的一个概率密度函数。这导致以下加权目标

LCFM(θ)=Et∼ω,Z,XtD(ut(Xt∣Z),utθ(Xt))=Et∼U,Z,Xtω(t)D(ut(Xt∣Z),utθ(Xt)).(4.27)\mathcal{L}_{\mathrm{CFM}}(\theta)=\mathbb{E}_{t\sim\omega,Z,X_t}D(u_t(X_t|Z),u_t^\theta(X_t))=\mathbb{E}_{t\sim U,Z,X_t\omega}(t)D(u_t(X_t|Z),u_t^\theta(X_t)).\tag{4.27}LCFM(θ)=Etω,Z,XtD(ut(XtZ),utθ(Xt))=EtU,Z,Xtω(t)D(ut(XtZ),utθ(Xt)).(4.27)

尽管在数学上是等价的,但在大规模图像生成任务中,采样 t∼ωt \sim \omegatω 比使用权重 ω(t)\omega(t)ω(t) 能带来更好的性能 (Esser et al., 2024)。

4.6 Solving conditional generation with conditional flows

到目前为止,我们已将训练流模型 utθu_t^\thetautθ 的问题简化为以下三个步骤:
(i) 寻找能够产生满足 (4.5) 中边界条件的边际概率路径 pt(x)p_t(x)pt(x)条件概率路径 pt∣Z(x∣z)p_{t|Z}(x|z)ptZ(xz)
(ii) 寻找能够生成该条件概率路径的条件速度场 ut(x∣z)u_t(x|z)ut(xz)
(iii) 使用条件流匹配损失(参见方程 (4.23))进行训练。

我们现在讨论如何具体实现步骤 (i) 和 (ii),即如何设计此类条件概率路径和速度场。我们将提出一种灵活的方法,通过条件流(conditional flows) 的特定构造来设计此类条件概率路径和速度场。其核心思想如下:定义一个满足边界条件 (4.6) 的流模型 Xt∣1X_{t|1}Xt∣1(类似于 (3.16)),然后通过微分 (3.20) 从中提取速度场。此过程同时定义了 pt∣1(x∣x1)p_{t|1}(x|x_1)pt∣1(xx1)ut(x∣x1)u_t(x|x_1)ut(xx1)。更详细地,定义条件流模型

Xt∣1=ψt(X0∣x1),where X0∼π0∣1(⋅∣x1),(4.28)X_{t|1} = \psi_t(X_0 \mid x_1), \quad \text{where } X_0 \sim \pi_{0|1} (\cdot \mid x_1),\tag{4.28}Xt∣1=ψt(X0x1),where X0π0∣1(x1),(4.28)

其中 ψ:[0,1)×Rd×Rd→Rd\psi: [0, 1) \times \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}^dψ:[0,1)×Rd×RdRd 是一个条件流,由下式定义:

ψt(x∣x1)={xt=0x1t=1,(4.29)\begin{align*} \psi_t(x|x_1) = \begin{cases} x & t = 0 \\ x_1 & t = 1 \end{cases}, \end{align*} \tag{4.29}ψt(xx1)={xx1t=0t=1,(4.29)

该函数在 (t,x)(t, x)(t,x)光滑,并且在 xxx 上是微分同胚。(此处的“光滑”是指 ψt(x∣x1)\psi_t(x|x_1)ψt(xx1) 关于 tttxxx 的所有导数均存在且连续:即 C∞([0,1)×Rd,Rd)C^\infty([0, 1) \times \mathbb{R}^d, \mathbb{R}^d)C([0,1)×Rd,Rd)。这些条件可以放宽至 C2([0,1)×Rd,Rd)C^2([0, 1) \times \mathbb{R}^d, \mathbb{R}^d)C2([0,1)×Rd,Rd),但会牺牲一定的简洁性。)

推送映射公式 (3.15)Xt∣1X_{t|1}Xt∣1 的概率密度定义为:

pt∣1(x∣x1):=[ψt(⋅∣x1)♯π0∣1(⋅∣x1)](x),(4.30)p_{t|1}(x|x_1):=\begin{bmatrix}\psi_t(\cdot|x_1)_\sharp\pi_{0|1}(\cdot|x_1)\end{bmatrix}(x), \tag{4.30}pt∣1(xx1):=[ψt(x1)π0∣1(x1)](x),(4.30)

尽管在 CFM 损失的实际优化中我们不需要此表达式,但它在理论上用于证明 pt∣1p_{t|1}pt∣1 满足两个边界条件 (4.6)。首先,根据 (4.29),ψ0(⋅∣x1)\psi_0(\cdot | x_1)ψ0(x1) 是恒等映射,在时间 t=0t = 0t=0 时保持 π0∣1(⋅∣x1)\pi_{0|1}(\cdot | x_1)π0∣1(x1) 不变。其次,ψ1(⋅∣x1)=x1\psi_1(\cdot | x_1) = x_1ψ1(x1)=x1 是常值映射,随着 t→1t \to 1t1 将所有概率质量集中于 x1x_1x1。此外,请注意 ψt(⋅∣x1)\psi_t(\cdot | x_1)ψt(x1) 对于 t∈[0,1)t \in [0, 1)t[0,1) 是光滑微分同胚。因此,根据流与速度场的等价性(第 3.4.1 节),存在一个唯一的光滑条件速度场(参见方程 (3.20)),其形式为:
ut(x∣x1)=ψ˙t(ψt−1(x∣x1)∣x1).(4.31)u_t(x|x_1)=\dot{\psi}_t(\psi_t^{-1}(x|x_1)|x_1). \tag{4.31}ut(xx1)=ψ˙t(ψt1(xx1)x1).(4.31)

总结而言:我们进一步将寻找条件路径及相应生成速度场的任务,简化为只需构建一个满足 (4.29) 的条件流 ψt(⋅∣x1)\psi_t(\cdot | x_1)ψt(x1)。在第 4.7 节中,我们将选择一个特别简单且具有某些理想性质(条件最优传输流)的 ψt(x∣x1)\psi_t(x | x_1)ψt(xx1),它将引致第 1 节中看到的标准 Flow Matching 算法。在第 4.8 节中,我们将讨论一个特定且众所周知的条件流族,即仿射流(affine flows),其中包含扩散模型文献中的一些已知示例。在第 5 节中,我们将使用条件流在流形(manifold) 上定义 Flow Matching,从而展示此方法的灵活性。


文章转载自:

http://UTxBObRT.bpdcw.cn
http://sKmn4R67.bpdcw.cn
http://h3KZNTR0.bpdcw.cn
http://IJerBcCC.bpdcw.cn
http://hynMFqBZ.bpdcw.cn
http://mxyoCeME.bpdcw.cn
http://sNPd4BHb.bpdcw.cn
http://M2wme2ZC.bpdcw.cn
http://WU8Y3K2k.bpdcw.cn
http://qwJvGw9z.bpdcw.cn
http://Ux3TXCVC.bpdcw.cn
http://IyR1e94j.bpdcw.cn
http://jCn3DGIn.bpdcw.cn
http://0KM2358U.bpdcw.cn
http://RaiTxdnC.bpdcw.cn
http://pN0JUAmo.bpdcw.cn
http://ZTnZSJjW.bpdcw.cn
http://jgjPxjIQ.bpdcw.cn
http://cm73VUKg.bpdcw.cn
http://MGJlV590.bpdcw.cn
http://Pbx2nNhP.bpdcw.cn
http://wxSWuTah.bpdcw.cn
http://jA4VjPz8.bpdcw.cn
http://3KmGwh25.bpdcw.cn
http://ZBu5DH1g.bpdcw.cn
http://NeFiJhxC.bpdcw.cn
http://Xp2S58Kw.bpdcw.cn
http://DTRBc8a2.bpdcw.cn
http://0ydYO4BP.bpdcw.cn
http://HbPqCcgY.bpdcw.cn
http://www.dtcms.com/a/378503.html

相关文章:

  • 内存泄漏一些事
  • 嵌入式学习day47-硬件-imx6ul-LED、Beep
  • 【数据结构】队列详解
  • C++/QT
  • GPT 系列论文1-2 两阶段半监督 + zero-shot prompt
  • 昆山精密机械公司8个Solidworks共用一台服务器
  • MasterGo钢笔Pen
  • 【算法--链表】143.重排链表--通俗讲解
  • 数据库的回表
  • 《Learning Langchain》阅读笔记13-Agent(1):Agent Architecture
  • MySQL索引(二):覆盖索引、最左前缀原则与索引下推详解
  • 【WS63】星闪开发资源整理
  • 守住矿山 “生命线”!QB800系列在线绝缘监测在矿用提升机电传系统应用方案
  • Altium Designer(AD)原理图更新PCB后所有器件变绿解决方案
  • DIFY 项目中通过 Makefile 调用 Dockerfile 并使用 sudo make build-web 命令构建 web 镜像的方法和注意事项
  • 联合索引最左前缀原则原理索引下推
  • 平衡车 -- 速度环
  • BPE算法深度解析:从零到一构建语言模型的词元化引擎
  • DIPMARK:一种隐蔽、高效且具备鲁棒性的大语言模型水印技术
  • mysql多表联查
  • 审美积累 | 移动端仪表盘
  • 面阵结构光3D相机三维坐标计算
  • 【大前端++】几大特征
  • 【持续更新】高质量的项目开发过程(C++)(前后端)
  • 淘宝商品视频批量自动化获取的常见渠道分享
  • ABAP 将多层json逐层解析转成内表
  • 一样的糖果
  • linux x86_64中打包qt
  • Windows 10 22H2 64位 【原版+优化版、版本号:19045.6332】
  • 学习日记-CSS-day53-9.11