Flow-GRPO:通过在线 RL 训练 Flow matching 模型
1 前言
本期内容,我们讲Flow-GRPO,他将基于强化学习的GRPO用于Flow matching,并在多个测试指标上获得了巨大的突破,下面让我们来看一下
视频:Flow-GRPO:通过在线 RL 训练 Flow matching 模型
参考论文:Flow-GRPO: Training Flow Matching Models via Online RL
参考代码:Flow-GRPO:
Training Flow Matching Models via Online RL
2 引入
在Flow matching当中,已经可以取得相当不错的效果了,一些基于此开发的模型,如SD3.5的生成质量也相当不错。然而,与最先进的模型相比,SD3.5的指标质量仍然有待提高。比如,与GPT-4o相比,SD3.5显然落后一大截。
在NLP领域,将基于RL(强化学习)的方法引入其中已经证明可以取得相当不错的效果,该方法可以让模型的生成结果更加的趋近于人类的偏好,比如DPO、GRPO等等
强化学习除了应用于NLP领域,在CV领域中也逐渐大放异彩,而Flow-GRPO,就是将GRPO用于Flow matching当中。
3 Flow matching
先回顾一下Flow matching,假定存在x0∼X0x_0\sim X_0x0∼X0为真实的数据样本,x1∼X1x_1\sim X_1x1∼X1为噪声样本,以Rectified flow为例,任意时刻的状态可以表示为
xt=(1−t)x0+tx1x_t = (1-t)x_0+tx_1 xt=(1−t)x0+tx1
其中t∈[0,1]t\in [0,1]t∈[0,1],我们可通过训练得到一个近似向量场vθ(xt,t)v_\theta(x_t,t)vθ(xt,t)
L(θ)=Et,x0,x1[∥v−vθ(xt,t)∥2]\mathcal{L}(\theta)=\mathbb{E}_{t,x_0,x_1}\left[ \Vert v - v_\theta(x_t,t) \Vert^2 \right] L(θ)=Et,x0,x1[∥v−vθ(xt,t)∥2]
其中,向量场v=x1−x0v=x_1-x_0v=x1−x0
4 方法
论文以SD3.5为例,将Flow-GRPO应用于T2I(文生图)当中。熟悉GRPO的小伙伴都知道,要使用GRPO的方法对Flow进行训练,要先解决ODE的问题:
- ODE无法在同一条件下生成多个样本,因此需要进行ODE到SDE的转化
4.1 GRPO
RL的优化目标一般为
maxθE(s0,a0,⋯,sT,aT)∼π0[∑t=0T(R(st,at)−βDKL(πθ(⋅∣st)∣∣πref(⋅∣st)))]\max_\theta \mathbb{E}_{(s_0,a_0,\cdots,s_T,a_T)\sim \pi_0}\left[ \sum\limits_{t=0}^T\left( R(s_t,a_t)-\beta D_{KL}(\pi_\theta(\cdot | s_t)||\pi_{ref}(\cdot|s_t)) \right) \right] θmaxE(s0,a0,⋯,sT,aT)∼π0[t=0∑T(R(st,at)−βDKL(πθ(⋅∣st)∣∣πref(⋅∣st)))]
去噪过程可以表示为一个MDP,给定提示词c,Flow可以得到一组图像{x0i}i=1G\{ x_0^i \}_{i=1}^G{x0i}i=1G,还有对应的一个采样轨迹{(xTi,xT−1i,⋯,x0i)}i=1G\{ (x_T^i,x_{T-1}^i,\cdots,x_0^i) \}_{i=1}^G{(xTi,xT−1i,⋯,x0i)}i=1G,我们可通过组归一化来计算第i张图形的优势,即
A^ti=R(x0i,c)−mean({R(x0i,c)}i=iG)std({R(x0i,c)}i=1G)\hat A_t^i=\frac{R(x_0^i,c)-\text{mean}(\{R(x_0^i,c)\}_{i=i}^G)}{ \text{std}(\{R(x_0^i,c)\}_{i=1}^G)} A^ti=std({R(x0i,c)}i=1G)R(x0i,c)−mean({R(x0i,c)}i=iG)
最大化GRPO的优化目标
JFlow-GRPO(θ)=Ec∼C,{xi}i=1G∼πθold(⋅∣c)f(r,A^,θ,ε,β)\mathcal{J}_{\text{Flow-GRPO}}(\theta)=\mathbb{E}_{c\sim \mathcal{C},\{x^i\}_{i=1}^G\sim \pi_{\theta_{\text{old}}}(\cdot|c)}f(r,\hat A,\theta,\varepsilon,\beta) JFlow-GRPO(θ)=Ec∼C,{xi}i=1G∼πθold(⋅∣c)f(r,A^,θ,ε,β)
其中
f(r,A^,θ,ε,β)=1G∑i=1G1T∑t=0T−1(min(rti(θ)A^ti,clip(1−ε,1+ε)A^ti)−βDKL(πθ∣∣πref)),andrti(θ)=pθ(xt−1i∣xti,c)pθold(xt−1i∣xti,c)f(r,\hat A,\theta,\varepsilon,\beta) = \frac{1}{G}\sum\limits_{i=1}^G\frac{1}{T}\sum\limits_{t=0}^{T-1}\left( \min\left( r_t^i(\theta)\hat A_t^i,\text{clip}(1-\varepsilon,1+\varepsilon)\hat A_t^i \right) - \beta D_{KL}(\pi_\theta||\pi_{ref})\right),\\\text{and}\quad r_t^{i}(\theta)=\frac{p_\theta(x_{t-1}^i|x_t^i,c)}{p_{\theta_{old}}(x_{t-1}^i|x_t^i,c)} f(r,A^,θ,ε,β)=G1i=1∑GT1t=0∑T−1(min(rti(θ)A^ti,clip(1−ε,1+ε)A^ti)−βDKL(πθ∣∣πref)),andrti(θ)=pθold(xt−1i∣xti,c)pθ(xt−1i∣xti,c)
4.2 从 ODE 到 SDE
如上式可见,无论是计算优势函数,还是优化目标当中,都依赖于随机采样来得到不同的轨迹。而基于ODE的去噪过程显然是不满足这一要求的,为此,我们需要把去噪过程从ODE转变为SDE,这样就有了随机性。
那么如何将ODE转化为SDE呢?其实,我们可以得到下面的等式(稍后证明)
dxt=[vt(xt)+σt22t(xt+(1−t)vt(xt))]dt+σtdwˉ(1)d x_t = \left[ v_t(x_t) + \frac{\sigma_t^2}{2t}(x_t+(1-t)v_t(x_t)) \right]dt + \sigma_td\bar w\tag{1} dxt=[vt(xt)+2tσt2(xt+(1−t)vt(xt))]dt+σtdwˉ(1)
dwˉd\bar wdwˉ表示维纳过程增量,σt\sigma_tσt是用于控制稳定程度的
可以看到,Eq.(1)仅仅依赖于向量场vvv,我们完全可以使用学习到的近似向量场vθv_\thetavθ去表示他。我们可以使用任意一个数值求解器,来得到生成轨迹
如欧拉-丸山法
去噪过程为
xt+Δt=xt+[vθ(xt,t)+σt22t(xt+(1−t)vθ(xt,t))]Δt+σtΔtε(2)x_{t+\Delta t} = x_t + \left[ v_{\theta}(x_t,t) + \frac{\sigma_t^2}{2t}(x_t + (1 - t )v_\theta(x_t,t)) \right]\Delta t + \sigma_t\sqrt{ \Delta t}\varepsilon\tag{2} xt+Δt=xt+[vθ(xt,t)+2tσt2(xt+(1−t)vθ(xt,t))]Δt+σtΔtε(2)
其中ε∼N(0,I),σt=at1−t\varepsilon \sim \mathcal{N}(0,I),\sigma_t = a\sqrt{\frac{t}{1 - t}}ε∼N(0,I),σt=a1−tt,aaa是控制噪声水平的超参数。
依据正态分布的性质可知,Eq.(2),也就是πθ(xt−1∣xt,c)\pi_\theta(x_{t-1}|x_t,c)πθ(xt−1∣xt,c)服从正态分布,那么很显然,我们可以直接KL散度为
DKL(π0∣∣πref)=∥xˉt+Δt,θ−xˉt+Δt,ref∥22σt2Δt=Δt2(σt(1−t)2t+1σt)2∥vθ(xt,t)−vref(xt,t)∥2(3)D_{KL}(\pi_0||\pi_{ref})=\frac{\Vert \bar x_{t+\Delta t,\theta} - \bar x_{t + \Delta t ,ref} \Vert^2}{2\sigma_t^2\Delta t} = \frac{ \Delta t}{2}\left( \frac{\sigma_t(1-t)}{2t} +\frac{1}{\sigma_t} \right)^2\Vert v_\theta(x_t,t) - v_{ref}(x_t,t) \Vert^2\tag{3} DKL(π0∣∣πref)=2σt2Δt∥xˉt+Δt,θ−xˉt+Δt,ref∥2=2Δt(2tσt(1−t)+σt1)2∥vθ(xt,t)−vref(xt,t)∥2(3)
这里直接代入KL散度公式来计算即可
5 Denoising Reduction
为了生成高质量的图像,Flow matching通常需要很多的去噪步骤,这使得RL训练的数据收集成本非常高。
论文发现,在进行RL训练的时候,是不需要太多的采样步数的,而在推理的时候保持原始的采样步依然能够获取高质量的样本。
为此,以SD3.5为例,在进行RL训练的时候,令采样时间步T=10;而在推理的时候,保持SD3.5默认设置T=40。
6 模型图
模型的训练流程见下图:
首先,分别采样5个高斯白噪声s0s_0s0,将提示词“A photo of four cups”作为条件,使用SDE数值求解器采样(T=10)得到sTs_TsT。然后将sTs_TsT送进奖励模型,得到R1,R2,R3,R4,⋯,RGR^1,R^2,R^3,R^4,\cdots,R^GR1,R2,R3,R4,⋯,RG作为奖励。用这些奖励根据上面的优势函数计算优势得到A^1,A^2,A^3,A^4,⋯,A^G\hat A^1,\hat A^2,\hat A^3,\hat A^4,\cdots,\hat A^GA^1,A^2,A^3,A^4,⋯,A^G,最后送进Flow-GRPO的损失函数计算损失即可。
7 数学证明
7.1 Eq.(1)证明
要将ODE转换成对应的SDE,就要先从ODE开始,我们有
dxt=vtdt(4)dx_t = v_tdt\tag{4} dxt=vtdt(4)
依据先前讲过的SDE,我们有对应的方程
dxx=fSDE(xt,t)dt+σtdw(5)dx_x = f_{\text{SDE}}(x_t,t)dt +\sigma_td w\tag{5} dxx=fSDE(xt,t)dt+σtdw(5)
我们需要求出fSDEf_{\text{SDE}}fSDE和vtv_tvt的关系式
依据Flow matching所提到的FP方程,Eq.(4)和Eq.(5)都有一个对应的连续性方程来表达概率密度路径ptp_tpt。对于Eq.(5),就是对应的FP方程(证明过程见什么是Fokker-Planck方程),即
KaTeX parse error: Undefined control sequence: \part at position 2: \̲p̲a̲r̲t̲ ̲_tp_t(x) = -\na…
而Eq.(4)对应的连续性方程为:
∂tpt(x)=−∇⋅[vt(xt,t)pt(x)](7)\partial_t p_t(x) = -\nabla \cdot [v_t(x_t,t)p_t(x)]\tag{7} ∂tpt(x)=−∇⋅[vt(xt,t)pt(x)](7)
当ptp_tpt和vtv_tvt的关系满足Eq.(7),则我们说向量场vvv能够生成对应的路径ptp_tpt。Eq.(6)同理。
那么接下来就简单了,联立Eq.(6)和Eq.(7)
−∇⋅[fSDE(xt,t)pt(x)]+12∇2[σt2pt(x)]=−∇⋅[vt(xt,t)pt(x)](8)-\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)]+\frac{1}{2}\nabla^2[\sigma_t^2p_t(x)] =-\nabla \cdot [v_t(x_t,t)p_t(x)]\tag{8} −∇⋅[fSDE(xt,t)pt(x)]+21∇2[σt2pt(x)]=−∇⋅[vt(xt,t)pt(x)](8)
因为
∇logpt(x)=1pt(x)⋅∇pt(x)→∇pt(x)=pt(x)⋅∇logpt(x)\nabla \log p_t(x) = \frac{1}{p_t(x)} \cdot \nabla p_t(x) \to \nabla p_t(x) = p_t(x)\cdot\nabla \log p_t(x) ∇logpt(x)=pt(x)1⋅∇pt(x)→∇pt(x)=pt(x)⋅∇logpt(x)
对Eq.(8)左侧第二项进行一下变化
∇2[σt2pt(x)]=σt2∇2pt(x)=σt2∇⋅(∇pt(x))=σt2∇⋅(pt(x)∇logpt(x))\begin{aligned} \nabla^2[\sigma_t^2p_t(x)] = &\sigma_t^2\nabla^2p_t(x) \\=& \sigma_t^2\nabla\cdot (\nabla p_t(x)) \\= & \sigma_t^2\nabla \cdot (p_t(x)\nabla \log p_t(x)) \end{aligned} ∇2[σt2pt(x)]===σt2∇2pt(x)σt2∇⋅(∇pt(x))σt2∇⋅(pt(x)∇logpt(x))
所以Eq.(8)等于:
−∇⋅[fSDE(xt,t)pt(x)]+12σt2∇⋅(pt(x)∇logpt(x))=−∇⋅[vt(xt,t)pt(x)]−fSDE(xt,t)pt(x)+12σt2pt(x)∇logpt(x)=−vt(xt,t)pt(x)fSDE(xt,t)pt(x)=vt(xt,t)pt(x)+12σt2pt(x)∇logpt(x)fSDE(xt,t)=vt(xt,t)+12σt2∇logpt(x)(9)\begin{aligned} -\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)]+\frac{1}{2}\sigma_t^2\nabla \cdot (p_t(x)\nabla \log p_t(x)) &=-\nabla \cdot [v_t(x_t,t)p_t(x)] \\ -f_{\text{SDE}}(x_t,t)p_t(x) + \frac{1}{2}\sigma_t^2p_t(x)\nabla \log p_t(x) &= - v_t(x_t,t)p_t(x) \\ f_{\text{SDE}}(x_t,t)p_t(x) &= v_t(x_t,t)p_t(x) + \frac{1}{2}\sigma_t^2p_t(x)\nabla \log p_t(x) \\ f_{\text{SDE}}(x_t,t) &= v_t(x_t,t) + \frac{1}{2}\sigma_t^2\nabla \log p_t(x) \end{aligned} \tag{9} −∇⋅[fSDE(xt,t)pt(x)]+21σt2∇⋅(pt(x)∇logpt(x))−fSDE(xt,t)pt(x)+21σt2pt(x)∇logpt(x)fSDE(xt,t)pt(x)fSDE(xt,t)=−∇⋅[vt(xt,t)pt(x)]=−vt(xt,t)pt(x)=vt(xt,t)pt(x)+21σt2pt(x)∇logpt(x)=vt(xt,t)+21σt2∇logpt(x)(9)
这样的话,我们就得到了fSDEf_{\text{SDE}}fSDE和vtv_tvt的关系式了
依据Score-Based Generative Modeling through Stochastic Differential Equations,正向过程Eq.(5)有对应的反向过程为
dxt=[f(xt,t)−g2(t)∇logpt(xt)]dt+g(t)dwˉ(10)dx_t = [f(x_t,t)-g^2(t)\nabla\log p_t(x_t)]dt + g(t)d\bar w\tag{10} dxt=[f(xt,t)−g2(t)∇logpt(xt)]dt+g(t)dwˉ(10)
其中,在本篇文章中,我们是让g(t)=σtg(t) = \sigma_tg(t)=σt,将Eq.(9)代入至Eq.(10)
dxt=[vt(xt,t)+12σt2∇logpt(xt)−σt2∇logpt(xt)]dt+σtdwˉdxt=[vt(xt,t)−σt22∇logpt(xt)]dt+σtdwˉ(11)\begin{aligned} dx_t = & \left[v_t(x_t,t) + \frac{1}{2}\sigma_t^2\nabla \log p_t(x_t) - \sigma_t^2\nabla\log p_t(x_t)\right]dt + \sigma_td\bar w \\dx_t = & \left[v_t(x_t,t)-\frac{\sigma_t^2}{2}\nabla\log p_t(x_t)\right]dt + \sigma_td\bar w \end{aligned}\tag{11} dxt=dxt=[vt(xt,t)+21σt2∇logpt(xt)−σt2∇logpt(xt)]dt+σtdwˉ[vt(xt,t)−2σt2∇logpt(xt)]dt+σtdwˉ(11)
对于Eq.(11),已知vtv_tvt,一旦∇logpt(xt)\nabla \log p_t(x_t)∇logpt(xt)也是已知的,那么就没有未知变量了,也就可以使用数值求解器生成样本了。因此我们还需要求解∇logpt(xt)\nabla \log p_t(x_t)∇logpt(xt)。
对于前向加噪过程,我们有xt=αtx0+βtx1x_t = \alpha_t x_0 + \beta_t x_1xt=αtx0+βtx1,在本期的Flow中,我们将加噪过程定义为αt=1−t;β=t\alpha_t = 1 - t;\beta = tαt=1−t;β=t,xtx_txt服从的概率分布为(假设一维的情况)
pt∣0(xt∣x0)=N(xt∣atx0,βt2I)=1βt2πexp{−(xt−atx0)22βt2}p_{t|0}(x_t|x_0) = \mathcal{N}(x_t|a_tx_0,\beta_t^2I) = \frac{1}{\beta_t\sqrt{2\pi}}\exp\{-\frac{(x_t-a_tx_0)^2}{2\beta_t^2}\} pt∣0(xt∣x0)=N(xt∣atx0,βt2I)=βt2π1exp{−2βt2(xt−atx0)2}
其对数结果为
logpt∣0(xt∣x0)=log(1βt2πexp{−(xt−atx0)22βt2})=log1βt2π−(xt−atx0)22βt2\begin{aligned} \log p_{t|0}(x_t|x_0) = &\log \left( \frac{1}{\beta_t\sqrt{2\pi}}\exp\{-\frac{(x_t-a_tx_0)^2}{2\beta_t^2}\} \right) \\= &\log \frac{1}{\beta_t\sqrt{2\pi}} -\frac{(x_t-a_tx_0)^2}{2\beta_t^2} \end{aligned} logpt∣0(xt∣x0)==log(βt2π1exp{−2βt2(xt−atx0)2})logβt2π1−2βt2(xt−atx0)2
所以
∇logpt∣0(xt∣x0)=−xt−αtx0βt2=βtx1βt2=−x1βt\nabla\log p_{t|0}(x_t|x_0) = -\frac{x_t - \alpha_tx_0}{\beta_t^2} = \frac{\beta_tx_1}{\beta_t^2} = -\frac{x_1}{\beta_t} ∇logpt∣0(xt∣x0)=−βt2xt−αtx0=βt2βtx1=−βtx1
因此
∇logpt(xt)=1pt(xt)∇pt(xt)=1pt(xt)∇∫pt,0(xt,x0)dx0=1pt(xt)∫∇pt,0(xt,x0)dx0=1pt(xt)∫∇[pt∣0(xt∣x0)p0(x0)]dx0=1pt(xt)∫p0(x0)∇pt∣0(xt∣x0)dx0=1pt(xt)∫p0(x0)⋅pt∣0(xt∣x0)∇logpt∣0(xt∣x0)dx0=1pt(xt)∫pt,0(xt,x0)∇logpt∣0(xt∣x0)dx0=1pt(xt)∫p0∣t(x0∣xt)pt(xt)∇logpt∣0(xt∣x0)dx0=∫p0∣t(x0∣xt)∇logpt∣0(xt∣x0)dx0=∫x0∫x1p0∣t(x0,x1∣xt)dx1∇logpt∣0(xt∣x0)dx0=∫x0∫x1p0∣t(x0,x1∣xt)∇logpt∣0(xt∣x0)dx1dx0=E[∇logpt∣0(xt∣x0)∣xt]=E[−x1βt∣xt]=−1βtE[x1∣xt](12)\begin{aligned} \nabla \log p_t(x_t) = & \frac{1}{p_t(x_t)}\nabla p_t(x_t) \\ = & \frac{1}{p_t(x_t)}\nabla\int p_{t,0}(x_t,x_0)dx_0 \\ = & \frac{1}{p_t(x_t)}\int \nabla p_{t,0}(x_t,x_0)dx_0 \\ = & \frac{1}{p_t(x_t)}\int \nabla \left[p_{t|0}(x_t|x_0)p_0(x_0)\right]dx_0 \\ = & \frac{1}{p_t(x_t)}\int p_0(x_0) \nabla p_{t|0}(x_t|x_0)dx_0 \\ = & \frac{1}{p_t(x_t)}\int p_0(x_0) \cdot p_{t|0}(x_t|x_0)\nabla\log p_{t|0}(x_t|x_0) dx_0 \\ = & \frac{1}{p_t(x_t)}\int p_{t,0}(x_t,x_0) \nabla\log p_{t|0}(x_t|x_0) dx_0 \\ = & \frac{1}{p_t(x_t)}\int p_{0|t}(x_0|x_t)p_t(x_t) \nabla\log p_{t|0}(x_t|x_0) dx_0 \\ = & \int p_{0|t}(x_0|x_t)\nabla\log p_{t|0}(x_t|x_0) dx_0 \\ = & \int_{x_0}\int_{x_1} p_{0|t}(x_0,x_1|x_t)dx_1\nabla\log p_{t|0}(x_t|x_0) dx_0 \\ = & \int_{x_0}\int_{x_1} p_{0|t}(x_0,x_1|x_t) \nabla\log p_{t|0}(x_t|x_0) dx_1 dx_0 \\ = & \mathbb{E}\left[ \nabla \log p_{t|0}(x_t|x_0) |x_t\right] \\ = & \mathbb{E}\left[ -\frac{x_1}{\beta_t} |x_t\right] \\ = & -\frac{1}{\beta_t}\mathbb{E}\left[ x_1|x_t\right] \end{aligned}\tag{12} ∇logpt(xt)==============pt(xt)1∇pt(xt)pt(xt)1∇∫pt,0(xt,x0)dx0pt(xt)1∫∇pt,0(xt,x0)dx0pt(xt)1∫∇[pt∣0(xt∣x0)p0(x0)]dx0pt(xt)1∫p0(x0)∇pt∣0(xt∣x0)dx0pt(xt)1∫p0(x0)⋅pt∣0(xt∣x0)∇logpt∣0(xt∣x0)dx0pt(xt)1∫pt,0(xt,x0)∇logpt∣0(xt∣x0)dx0pt(xt)1∫p0∣t(x0∣xt)pt(xt)∇logpt∣0(xt∣x0)dx0∫p0∣t(x0∣xt)∇logpt∣0(xt∣x0)dx0∫x0∫x1p0∣t(x0,x1∣xt)dx1∇logpt∣0(xt∣x0)dx0∫x0∫x1p0∣t(x0,x1∣xt)∇logpt∣0(xt∣x0)dx1dx0E[∇logpt∣0(xt∣x0)∣xt]E[−βtx1∣xt]−βt1E[x1∣xt](12)
对于向量场vvv,在我们之前的表达式中,是有vt=x1−x0v_t = x_1 - x_0vt=x1−x0。然而,由于路径存在交叉点,所以我们之前说过,我们学习到的vθv_\thetavθ其实并不等于vtv_tvt,而是vtv_tvt的数学期望。我们可以通过以下来证明:
L=∫01Ex0,x1[∥x1−x0−vθ(xt,t)∥2]=∫01Ex0,x1[∥x1−x0∥2+∣∣vθ(xt,t)∣∣2−2(x1−x0)Tvθ(xt,t)]dt=∫01{Ex0,x1[∣∣vθ(xt,t)∣∣2]−2Ex0,x1[(x1−x0)Tvθ(xt,t)]}dt+C=∫01{Ext[∣∣vθ(xt,t)∣∣2]−2Ex0,x1[(x1−x0)Tvθ(xt,t)]}dt+C(13)\begin{aligned} \mathcal{L} = & \int_0^1 \mathbb{E}_{x_0,x_1}\left[ \Vert x_1-x_0 - v_\theta(x_t,t)\Vert^2 \right] \\ = & \int_0^1 \mathbb{E}_{x_0,x_1}\left[ \Vert x_1-x_0 \Vert^2 + ||v_\theta(x_t,t)||^2- 2(x_1 - x_0)^Tv_\theta(x_t,t) \right]dt \\ = & \int_0^1 \left\{\mathbb{E}_{x_0,x_1}\left[ ||v_\theta(x_t,t)||^2\right] -2\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \right]\right\}dt + C \\ = & \int_0^1 \left\{\mathbb{E}_{x_t}\left[ ||v_\theta(x_t,t)||^2\right] -2\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \right]\right\}dt + C \end{aligned}\tag{13} L====∫01Ex0,x1[∥x1−x0−vθ(xt,t)∥2]∫01Ex0,x1[∥x1−x0∥2+∣∣vθ(xt,t)∣∣2−2(x1−x0)Tvθ(xt,t)]dt∫01{Ex0,x1[∣∣vθ(xt,t)∣∣2]−2Ex0,x1[(x1−x0)Tvθ(xt,t)]}dt+C∫01{Ext[∣∣vθ(xt,t)∣∣2]−2Ex0,x1[(x1−x0)Tvθ(xt,t)]}dt+C(13)
第一项是因为给定x0,x1x_0,x_1x0,x1,有xt=tx1+(1−t)x0x_t = tx_1 + (1-t)x_0xt=tx1+(1−t)x0,所以可以直接写成关于xtx_txt的数学期望。
第二项我们可以继续变化,由全期望公式:EY=EX[EY(Y∣X)]\mathbb{E}Y = \mathbb{E}_X[\mathbb{E}_Y(Y|X)]EY=EX[EY(Y∣X)],可得
Ex0,x1[(x1−x0)Tvθ(xt,t)]=Ext[Ex0,x1[(x1−x0)Tvθ(xt,t)∣xt]]=Ext[Ex0,x1[(x1−x0)∣xt]Tvθ(xt,t)]\begin{aligned} \mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) \right] = & \mathbb{E}_{x_t}[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^Tv_\theta(x_t,t) |x_t\right]] \\ = & \mathbb{E}_{x_t}[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0) |x_t\right]^Tv_\theta(x_t,t)] \end{aligned} Ex0,x1[(x1−x0)Tvθ(xt,t)]==Ext[Ex0,x1[(x1−x0)Tvθ(xt,t)∣xt]]Ext[Ex0,x1[(x1−x0)∣xt]Tvθ(xt,t)]
所以Eq.(13)为
L=∫01{Ext[∣∣vθ(xt,t)∣∣2]−2Ext[Ex0,x1[(x1−x0)∣xt]Tvθ(xt,t)]}dt+C=∫01Ext[∥Ex0,x1[x1−x0∣xt]−vθ(xt,t)∥2]dt+C′(14)\begin{aligned} \mathcal{L} = &\int_0^1 \left\{\mathbb{E}_{x_t}\left[ ||v_\theta(x_t,t)||^2\right] -2\mathbb{E}_{x_t}[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0) |x_t\right]^Tv_\theta(x_t,t)]\right\}dt + C \\ = & \int_0^1 \mathbb{E}_{x_t}\left[ \Vert \mathbb{E}_{x_0,x_1}[x_1-x_0|x_t] - v_\theta(x_t,t) \Vert^2 \right]dt +C' \end{aligned}\tag{14} L==∫01{Ext[∣∣vθ(xt,t)∣∣2]−2Ext[Ex0,x1[(x1−x0)∣xt]Tvθ(xt,t)]}dt+C∫01Ext[∥Ex0,x1[x1−x0∣xt]−vθ(xt,t)∥2]dt+C′(14)
此时我们不难看出,我们所学习到的vθ(xt,t)=Ex0,x1[x1−x0∣xt]v_\theta(x_t,t) = \mathbb{E}_{x_0,x_1}[x_1-x_0|x_t]vθ(xt,t)=Ex0,x1[x1−x0∣xt]
我们继续转化
vθ(xt,t)=Ex0,x1[x1−x0∣xt]=Ex0,x1[x1∣xt]−Ex0,x1[x0∣xt]=Ex0,x1[x1∣xt]−Ex0,x1[xt−tx11−t∣xt]=Ex0,x1[x1∣xt]−Ex0,x1[xt1−t∣xt]+Ex0,x1[tx11−t∣xt]=Ex0,x1[x1∣xt]−xt1−t+t1−tEx0,x1[x1∣xt]=−xt1−t+11−tEx0,x1[x1∣xt]=−xt1−t+11−t⋅(−βt∇logpt(xt))=−xt1−t−t1−t⋅∇logpt(xt)\begin{aligned} v_\theta(x_t,t) = & \mathbb{E}_{x_0,x_1}[x_1-x_0|x_t] \\ = & \mathbb{E}_{x_0,x_1}[x_1|x_t] - \mathbb{E}_{x_0,x_1}[x_0|x_t] \\ = & \mathbb{E}_{x_0,x_1}[x_1|x_t] - \mathbb{E}_{x_0,x_1}\left[\frac{x_t-tx_1}{1-t}|x_t\right] \\ = & \mathbb{E}_{x_0,x_1}[x_1|x_t] - \mathbb{E}_{x_0,x_1}\left[\frac{x_t}{1-t}|x_t\right] + \mathbb{E}_{x_0,x_1}\left[\frac{tx_1}{1-t}|x_t\right] \\ = & \mathbb{E}_{x_0,x_1}[x_1|x_t] -\frac{x_t}{1-t}+ \frac{t}{1-t}\mathbb{E}_{x_0,x_1}\left[x_1|x_t\right] \\ = & -\frac{x_t}{1-t} + \frac{1}{1-t}\mathbb{E}_{x_0,x_1}\left[x_1|x_t\right] \\ = & -\frac{x_t}{1-t} + \frac{1}{1-t}\cdot (-\beta_t\nabla \log p_t(x_t)) \\ = & -\frac{x_t}{1-t} - \frac{t}{1-t}\cdot \nabla \log p_t(x_t) \end{aligned} vθ(xt,t)========Ex0,x1[x1−x0∣xt]Ex0,x1[x1∣xt]−Ex0,x1[x0∣xt]Ex0,x1[x1∣xt]−Ex0,x1[1−txt−tx1∣xt]Ex0,x1[x1∣xt]−Ex0,x1[1−txt∣xt]+Ex0,x1[1−ttx1∣xt]Ex0,x1[x1∣xt]−1−txt+1−ttEx0,x1[x1∣xt]−1−txt+1−t1Ex0,x1[x1∣xt]−1−txt+1−t1⋅(−βt∇logpt(xt))−1−txt−1−tt⋅∇logpt(xt)
把∇\nabla∇单独放等式左侧可得
∇logpt(xt)=−xt−1−ttvθ(xt,t)\nabla\log p_t(x_t) = -\frac{x}{t}-\frac{1-t}{t}v_\theta(x_t,t) ∇logpt(xt)=−tx−t1−tvθ(xt,t)
把它代入到Eq.(11)可得最终的表达式
dxt=[vt(xt,t)−σt22(−xt−1−ttvθ(xt,t))]dt+σtdwˉdxt=[vt(xt,t)+σt22t(x+(1−t)vθ(xt,t))]dt+σtdwˉ(15)\begin{aligned} dx_t = & \left[v_t(x_t,t)-\frac{\sigma_t^2}{2}\left( -\frac{x}{t}-\frac{1-t}{t}v_\theta(x_t,t) \right)\right]dt + \sigma_td\bar w\\ dx_t = & \left[v_t(x_t,t)+\frac{\sigma_t^2}{2t}\left( x+(1-t)v_\theta(x_t,t) \right)\right]dt + \sigma_td\bar w \end{aligned}\tag{15} dxt=dxt=[vt(xt,t)−2σt2(−tx−t1−tvθ(xt,t))]dt+σtdwˉ[vt(xt,t)+2tσt2(x+(1−t)vθ(xt,t))]dt+σtdwˉ(15)
至此得证
8 参考
[1] 深入理解Rectified Flow,完善统一扩散框架 - 知乎
9 结束
好了,本期内容到此为止了,如有问题,还望指出,阿里嘎多!