
Flow-GRPO: Training Flow Matching Models via Online RL

1 Introduction

In this post we look at Flow-GRPO, which brings the RL algorithm GRPO to flow matching models and achieves substantial gains across several benchmarks. Let's dive in.
Video: Flow-GRPO: Training Flow Matching Models via Online RL

Reference paper: Flow-GRPO: Training Flow Matching Models via Online RL

Reference code: Flow-GRPO: Training Flow Matching Models via Online RL

2 Motivation

Flow matching already delivers strong results, and models built on it, such as SD3.5, generate images of impressive quality. Compared with the strongest current systems, however, SD3.5's benchmark scores still leave room for improvement: against GPT-4o, for instance, it clearly falls well behind.

In NLP, RL-based methods such as DPO and GRPO have proven highly effective, steering model outputs toward human preferences.

Beyond NLP, reinforcement learning is increasingly making its mark in CV as well, and Flow-GRPO is exactly that: GRPO applied to flow matching.


3 Flow matching

First, a quick review of flow matching. Let $x_0 \sim X_0$ be a real data sample and $x_1 \sim X_1$ a noise sample. Taking Rectified Flow as the example, the state at any time $t$ can be written as
$$x_t = (1-t)\,x_0 + t\,x_1$$
where $t \in [0,1]$. We train an approximate vector field $v_\theta(x_t,t)$ by minimizing
$$\mathcal{L}(\theta)=\mathbb{E}_{t,x_0,x_1}\left[ \Vert v - v_\theta(x_t,t) \Vert^2 \right]$$
where the target vector field is $v = x_1 - x_0$.
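As a quick illustration, here is a minimal PyTorch sketch of this training objective (the `model` name and its `(x_t, t)` interface are assumptions for the example, not the paper's code):

```python
import torch

def rectified_flow_loss(model, x0):
    """One loss evaluation: regress v_theta(x_t, t) onto the target v = x1 - x0."""
    x1 = torch.randn_like(x0)                      # noise sample x1 ~ N(0, I)
    t = torch.rand(x0.shape[0], device=x0.device)  # t ~ U[0, 1], one per sample
    t_ = t.view(-1, *([1] * (x0.dim() - 1)))       # broadcast t over data dims
    xt = (1 - t_) * x0 + t_ * x1                   # x_t = (1 - t) x0 + t x1
    return ((model(xt, t) - (x1 - x0)) ** 2).mean()
```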

4 Method

The paper takes SD3.5 as the base model and applies Flow-GRPO to text-to-image (T2I) generation. As anyone familiar with GRPO knows, training a flow model with GRPO first requires dealing with a problem of the ODE sampler:

  • A deterministic ODE cannot produce multiple distinct samples under the same condition, so the ODE must be converted into an SDE.

4.1 GRPO

The general RL objective is
$$\max_\theta \mathbb{E}_{(s_0,a_0,\cdots,s_T,a_T)\sim \pi_\theta}\left[ \sum_{t=0}^{T}\left( R(s_t,a_t)-\beta D_{KL}\big(\pi_\theta(\cdot \mid s_t)\,\Vert\,\pi_{ref}(\cdot\mid s_t)\big) \right) \right]$$
The denoising process can be cast as an MDP. Given a prompt $c$, the flow model produces a group of images $\{x_0^i\}_{i=1}^G$, each with a sampling trajectory $\{(x_T^i,x_{T-1}^i,\cdots,x_0^i)\}_{i=1}^G$. The advantage of the $i$-th image is computed by group normalization:
$$\hat A_t^i=\frac{R(x_0^i,c)-\operatorname{mean}\big(\{R(x_0^i,c)\}_{i=1}^G\big)}{\operatorname{std}\big(\{R(x_0^i,c)\}_{i=1}^G\big)}$$
We then maximize the GRPO objective
$$\mathcal{J}_{\text{Flow-GRPO}}(\theta)=\mathbb{E}_{c\sim \mathcal{C},\,\{x^i\}_{i=1}^G\sim \pi_{\theta_{\text{old}}}(\cdot\mid c)}\,f(r,\hat A,\theta,\varepsilon,\beta)$$
where
$$f(r,\hat A,\theta,\varepsilon,\beta) = \frac{1}{G}\sum_{i=1}^G\frac{1}{T}\sum_{t=0}^{T-1}\left( \min\Big( r_t^i(\theta)\hat A_t^i,\ \operatorname{clip}\big(r_t^i(\theta),1-\varepsilon,1+\varepsilon\big)\hat A_t^i \Big) - \beta D_{KL}(\pi_\theta\Vert\pi_{ref})\right),\quad r_t^{i}(\theta)=\frac{p_\theta(x_{t-1}^i\mid x_t^i,c)}{p_{\theta_{\text{old}}}(x_{t-1}^i\mid x_t^i,c)}$$
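As a concrete illustration, here is a minimal PyTorch sketch of the group-normalized advantage and the clipped per-step surrogate above (function names and the `eps`/`beta` defaults are placeholder assumptions, not values from the paper):

```python
import torch

def grpo_advantages(rewards):
    """Group-normalized advantages from the G per-image rewards R(x_0^i, c)."""
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

def grpo_step_loss(logp_new, logp_old, advantages, kl, eps=0.2, beta=0.01):
    """Clipped surrogate for one denoising step, averaged over the group.
    logp_new / logp_old are log p(x_{t-1}|x_t, c) under pi_theta / pi_theta_old."""
    ratio = torch.exp(logp_new - logp_old)                      # r_t^i(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    # negate: optimizers minimize, while the GRPO objective is maximized
    return -(torch.min(unclipped, clipped) - beta * kl).mean()
```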

4.2 From ODE to SDE

As the equations above show, both the advantage computation and the objective rely on stochastic sampling to produce distinct trajectories. A deterministic ODE denoising process clearly cannot satisfy this, so we need to convert the denoising process from an ODE into an SDE, which introduces the required randomness.

So how do we convert the ODE into an SDE? It turns out the following identity holds (we prove it later):
$$dx_t = \left[ v_t(x_t) + \frac{\sigma_t^2}{2t}\big(x_t+(1-t)v_t(x_t)\big) \right]dt + \sigma_t\,d\bar w \tag{1}$$
where $d\bar w$ is a Wiener-process increment and $\sigma_t$ controls the level of stochasticity.

Note that Eq. (1) depends only on the vector field $v$, so we can substitute the learned approximation $v_\theta$ for it. Any numerical solver can then generate the sampling trajectory; with the Euler–Maruyama method, the denoising update is
$$x_{t+\Delta t} = x_t + \left[ v_{\theta}(x_t,t) + \frac{\sigma_t^2}{2t}\big(x_t + (1 - t)v_\theta(x_t,t)\big) \right]\Delta t + \sigma_t\sqrt{\Delta t}\,\varepsilon \tag{2}$$
where $\varepsilon \sim \mathcal{N}(0,I)$, $\sigma_t = a\sqrt{\frac{t}{1-t}}$, and $a$ is a hyperparameter controlling the noise level.
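Here is a minimal sketch of one Euler–Maruyama denoising step of Eq. (2). In this note's convention $x_1$ is noise and $x_0$ is data, so sampling walks $t$ from near 1 down to near 0 and $\Delta t$ is negative; the noise term then uses $\sqrt{|\Delta t|}$, and the time grid must avoid the endpoints $t=0$ and $t=1$, where $\sigma_t$ degenerates. The function interface and the default `a` are assumptions for the example:

```python
import math
import torch

def sde_denoise_step(v_theta, x_t, t, t_next, a=0.7):
    """One Euler-Maruyama step of the SDE sampler, Eq. (2)."""
    dt = t_next - t                       # negative when integrating 1 -> 0
    sigma_t = a * math.sqrt(t / (1 - t))  # sigma_t = a * sqrt(t / (1 - t))
    v = v_theta(x_t, t)
    drift = v + (sigma_t ** 2 / (2 * t)) * (x_t + (1 - t) * v)
    noise = torch.randn_like(x_t)         # eps ~ N(0, I)
    return x_t + drift * dt + sigma_t * math.sqrt(abs(dt)) * noise
```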

By the properties of the normal distribution, Eq. (2) implies that the transition kernel $\pi_\theta(x_{t-1}\mid x_t,c)$ is Gaussian. The KL divergence between the policy and the reference therefore has a closed form:
$$D_{KL}(\pi_\theta\Vert\pi_{ref})=\frac{\Vert \bar x_{t+\Delta t,\theta} - \bar x_{t + \Delta t,ref} \Vert^2}{2\sigma_t^2\Delta t} = \frac{\Delta t}{2}\left( \frac{\sigma_t(1-t)}{2t} +\frac{1}{\sigma_t} \right)^2\Vert v_\theta(x_t,t) - v_{ref}(x_t,t) \Vert^2 \tag{3}$$

This is just the standard KL formula for two Gaussians with the same variance $\sigma_t^2\Delta t$ and means $\bar x_{t+\Delta t,\theta}$ and $\bar x_{t+\Delta t,ref}$.
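A minimal sketch of Eq. (3) in code, assuming `v_theta` and `v_ref` are the two models' velocity predictions at the same $(x_t, t)$ and `dt` is the step size:

```python
import math

def kl_term(v_theta, v_ref, t, dt, a=0.7):
    """Closed-form KL of Eq. (3) between two Gaussian transition kernels that
    share the variance sigma_t^2 * |dt| and differ only in their means.
    Inputs are torch tensors of shape (batch, ...); returns a (batch,) tensor."""
    sigma_t = a * math.sqrt(t / (1 - t))
    coeff = (abs(dt) / 2) * (sigma_t * (1 - t) / (2 * t) + 1 / sigma_t) ** 2
    diff = v_theta - v_ref
    return coeff * (diff ** 2).flatten(1).sum(dim=1)  # per-sample squared norm
```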


5 Denoising Reduction

To generate high-quality images, flow matching typically needs many denoising steps, which makes data collection for RL training very expensive.

The paper finds that RL training does not actually need many sampling steps, while keeping the original step count at inference time still yields high-quality samples.

Concretely, for SD3.5 the number of sampling steps is set to T = 10 during RL training, while inference keeps SD3.5's default of T = 40.
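In code this amounts to nothing more than changing the step count between rollout collection and inference; a sketch, where `sample_sde` is a hypothetical sampler that iterates `sde_denoise_step` from above:

```python
T_TRAIN, T_INFER = 10, 40  # SD3.5: cheap rollouts for RL, full steps at test time

def collect_rollouts(model, prompts):
    return sample_sde(model, prompts, num_steps=T_TRAIN)  # RL data collection

def generate(model, prompts):
    return sample_sde(model, prompts, num_steps=T_INFER)  # high-quality samples
```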

6 Training Pipeline

The training pipeline is shown below:

[Figure: Flow-GRPO training pipeline]

First, sample $G$ (here 5) white Gaussian noise tensors $s_0$. With the prompt "A photo of four cups" as the condition, run the SDE numerical solver (T = 10) to obtain $s_T$. Each $s_T$ is fed to the reward model, producing the rewards $R^1,R^2,R^3,R^4,\cdots,R^G$. From these rewards, compute the advantages $\hat A^1,\hat A^2,\hat A^3,\hat A^4,\cdots,\hat A^G$ with the advantage formula above, and finally plug everything into the Flow-GRPO loss.
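Putting the pieces together, a hedged end-to-end sketch of one training iteration (all helper names here, `sample_sde_trajectory`, `reward_model`, `flow_grpo_loss`, are illustrative assumptions, not the paper's code):

```python
import torch

def train_iteration(policy, policy_old, policy_ref, reward_model, optimizer,
                    prompt="A photo of four cups", G=5, T=10):
    trajectories, rewards = [], []
    for _ in range(G):                                  # G rollouts per prompt
        traj = sample_sde_trajectory(policy_old, prompt, num_steps=T)
        trajectories.append(traj)                       # (s_0, ..., s_T)
        rewards.append(reward_model(traj[-1], prompt))  # reward on the final image
    advantages = grpo_advantages(torch.tensor(rewards)) # as sketched in Sec. 4.1
    loss = flow_grpo_loss(policy, policy_old, policy_ref,
                          trajectories, advantages, prompt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```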

7 Mathematical Proofs

7.1 Proof of Eq. (1)

To convert the ODE into its corresponding SDE, we start from the ODE itself:
$$dx_t = v_t\,dt \tag{4}$$
Following the earlier post on SDEs, the corresponding stochastic equation has the form
$$dx_t = f_{\text{SDE}}(x_t,t)\,dt +\sigma_t\,dw \tag{5}$$
We need to express $f_{\text{SDE}}$ in terms of $v_t$.

By the Fokker–Planck (FP) equation mentioned in the flow matching post, Eq. (4) and Eq. (5) each induce a continuity equation for the probability density path $p_t$. For Eq. (5) it is the FP equation itself (for the derivation, see 什么是Fokker-Planck方程):
$$\partial_t p_t(x) = -\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)] + \frac{1}{2}\nabla^2[\sigma_t^2 p_t(x)] \tag{6}$$
while the continuity equation for Eq. (4) is
$$\partial_t p_t(x) = -\nabla \cdot [v_t(x_t,t)p_t(x)] \tag{7}$$
When $p_t$ and $v_t$ satisfy Eq. (7), we say the vector field $v$ generates the path $p_t$; the same reading applies to Eq. (6).

The rest is straightforward: equate Eq. (6) and Eq. (7),
$$-\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)]+\frac{1}{2}\nabla^2[\sigma_t^2p_t(x)] =-\nabla \cdot [v_t(x_t,t)p_t(x)] \tag{8}$$
Since
$$\nabla \log p_t(x) = \frac{1}{p_t(x)}\nabla p_t(x) \ \Rightarrow\ \nabla p_t(x) = p_t(x)\,\nabla \log p_t(x)$$
the second term on the left side of Eq. (8) can be rewritten as
$$\begin{aligned} \nabla^2[\sigma_t^2p_t(x)] &= \sigma_t^2\nabla^2p_t(x) \\&= \sigma_t^2\nabla\cdot (\nabla p_t(x)) \\&= \sigma_t^2\nabla \cdot (p_t(x)\nabla \log p_t(x)) \end{aligned}$$
So Eq. (8) becomes
$$\begin{aligned} -\nabla \cdot [f_{\text{SDE}}(x_t,t)p_t(x)]+\frac{1}{2}\sigma_t^2\nabla \cdot (p_t(x)\nabla \log p_t(x)) &=-\nabla \cdot [v_t(x_t,t)p_t(x)] \\ -f_{\text{SDE}}(x_t,t)p_t(x) + \frac{1}{2}\sigma_t^2p_t(x)\nabla \log p_t(x) &= - v_t(x_t,t)p_t(x) \\ f_{\text{SDE}}(x_t,t)p_t(x) &= v_t(x_t,t)p_t(x) + \frac{1}{2}\sigma_t^2p_t(x)\nabla \log p_t(x) \\ f_{\text{SDE}}(x_t,t) &= v_t(x_t,t) + \frac{1}{2}\sigma_t^2\nabla \log p_t(x) \end{aligned} \tag{9}$$
(In the second line we drop the divergence operator: matching the arguments of the divergences is sufficient, since any $f_{\text{SDE}}$ satisfying it makes the two continuity equations agree.) This gives us the relationship between $f_{\text{SDE}}$ and $v_t$.

By Score-Based Generative Modeling through Stochastic Differential Equations, the forward process Eq. (5) has the corresponding reverse process
$$dx_t = [f(x_t,t)-g^2(t)\nabla\log p_t(x_t)]\,dt + g(t)\,d\bar w \tag{10}$$
In this article we take $g(t) = \sigma_t$. Substituting Eq. (9) into Eq. (10):
$$\begin{aligned} dx_t &= \left[v_t(x_t,t) + \frac{1}{2}\sigma_t^2\nabla \log p_t(x_t) - \sigma_t^2\nabla\log p_t(x_t)\right]dt + \sigma_t\,d\bar w \\ dx_t &= \left[v_t(x_t,t)-\frac{\sigma_t^2}{2}\nabla\log p_t(x_t)\right]dt + \sigma_t\,d\bar w \end{aligned} \tag{11}$$
In Eq. (11), $v_t$ is known; if $\nabla \log p_t(x_t)$ were also known, there would be no unknowns left and a numerical solver could generate samples. So the remaining task is to derive $\nabla \log p_t(x_t)$.

For the forward noising process we have $x_t = \alpha_t x_0 + \beta_t x_1$; for the flow used here, $\alpha_t = 1 - t$ and $\beta_t = t$. The conditional distribution of $x_t$ given $x_0$ is (taking the one-dimensional case)
$$p_{t|0}(x_t|x_0) = \mathcal{N}(x_t \mid \alpha_t x_0,\beta_t^2 I) = \frac{1}{\beta_t\sqrt{2\pi}}\exp\left\{-\frac{(x_t-\alpha_t x_0)^2}{2\beta_t^2}\right\}$$
Taking the logarithm,
$$\begin{aligned} \log p_{t|0}(x_t|x_0) &= \log \left( \frac{1}{\beta_t\sqrt{2\pi}}\exp\left\{-\frac{(x_t-\alpha_t x_0)^2}{2\beta_t^2}\right\} \right) \\&= \log \frac{1}{\beta_t\sqrt{2\pi}} -\frac{(x_t-\alpha_t x_0)^2}{2\beta_t^2} \end{aligned}$$
so
$$\nabla\log p_{t|0}(x_t|x_0) = -\frac{x_t - \alpha_t x_0}{\beta_t^2} = -\frac{\beta_t x_1}{\beta_t^2} = -\frac{x_1}{\beta_t}$$
Therefore
$$\begin{aligned} \nabla \log p_t(x_t) &= \frac{1}{p_t(x_t)}\nabla p_t(x_t) \\ &= \frac{1}{p_t(x_t)}\nabla\int p_{t,0}(x_t,x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int \nabla p_{t,0}(x_t,x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int \nabla \left[p_{t|0}(x_t|x_0)p_0(x_0)\right]dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_0(x_0)\, \nabla p_{t|0}(x_t|x_0)\,dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_0(x_0)\, p_{t|0}(x_t|x_0)\,\nabla\log p_{t|0}(x_t|x_0)\, dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_{t,0}(x_t,x_0)\, \nabla\log p_{t|0}(x_t|x_0)\, dx_0 \\ &= \frac{1}{p_t(x_t)}\int p_{0|t}(x_0|x_t)\,p_t(x_t)\, \nabla\log p_{t|0}(x_t|x_0)\, dx_0 \\ &= \int p_{0|t}(x_0|x_t)\,\nabla\log p_{t|0}(x_t|x_0)\, dx_0 \\ &= \int_{x_0}\int_{x_1} p_{0,1|t}(x_0,x_1|x_t)\,dx_1\,\nabla\log p_{t|0}(x_t|x_0)\, dx_0 \\ &= \int_{x_0}\int_{x_1} p_{0,1|t}(x_0,x_1|x_t)\, \nabla\log p_{t|0}(x_t|x_0)\, dx_1\, dx_0 \\ &= \mathbb{E}\left[ \nabla \log p_{t|0}(x_t|x_0) \mid x_t\right] \\ &= \mathbb{E}\left[ -\frac{x_1}{\beta_t} \,\Big|\, x_t\right] \\ &= -\frac{1}{\beta_t}\mathbb{E}\left[ x_1 \mid x_t\right] \end{aligned} \tag{12}$$
Recall that earlier we wrote the vector field as $v_t = x_1 - x_0$. However, because transport paths cross, the learned $v_\theta$ is not $v_t$ itself but, as we noted before, the conditional expectation of $v_t$. This can be shown as follows:
$$\begin{aligned} \mathcal{L} &= \int_0^1 \mathbb{E}_{x_0,x_1}\left[ \Vert x_1-x_0 - v_\theta(x_t,t)\Vert^2 \right]dt \\ &= \int_0^1 \mathbb{E}_{x_0,x_1}\left[ \Vert x_1-x_0 \Vert^2 + \Vert v_\theta(x_t,t)\Vert^2- 2(x_1 - x_0)^T v_\theta(x_t,t) \right]dt \\ &= \int_0^1 \left\{\mathbb{E}_{x_0,x_1}\left[ \Vert v_\theta(x_t,t)\Vert^2\right] -2\,\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^T v_\theta(x_t,t) \right]\right\}dt + C \\ &= \int_0^1 \left\{\mathbb{E}_{x_t}\left[ \Vert v_\theta(x_t,t)\Vert^2\right] -2\,\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^T v_\theta(x_t,t) \right]\right\}dt + C \end{aligned} \tag{13}$$
The first term can be rewritten this way because, given $x_0$ and $x_1$, we have $x_t = t x_1 + (1-t)x_0$, so the expectation can be taken directly over $x_t$.

The second term can be transformed further. By the law of total expectation, $\mathbb{E}[Y] = \mathbb{E}_X[\mathbb{E}_Y(Y\mid X)]$, we get
$$\begin{aligned} \mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^T v_\theta(x_t,t) \right] &= \mathbb{E}_{x_t}\Big[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0)^T v_\theta(x_t,t) \mid x_t\right]\Big] \\ &= \mathbb{E}_{x_t}\Big[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0) \mid x_t\right]^T v_\theta(x_t,t)\Big] \end{aligned}$$
So Eq. (13) becomes
$$\begin{aligned} \mathcal{L} &= \int_0^1 \left\{\mathbb{E}_{x_t}\left[ \Vert v_\theta(x_t,t)\Vert^2\right] -2\,\mathbb{E}_{x_t}\Big[\mathbb{E}_{x_0,x_1}\left[(x_1 - x_0) \mid x_t\right]^T v_\theta(x_t,t)\Big]\right\}dt + C \\ &= \int_0^1 \mathbb{E}_{x_t}\left[ \Vert \mathbb{E}_{x_0,x_1}[x_1-x_0\mid x_t] - v_\theta(x_t,t) \Vert^2 \right]dt +C' \end{aligned} \tag{14}$$
From this it is clear that what we actually learn is $v_\theta(x_t,t) = \mathbb{E}_{x_0,x_1}[x_1-x_0\mid x_t]$.

Continuing the manipulation:
$$\begin{aligned} v_\theta(x_t,t) &= \mathbb{E}_{x_0,x_1}[x_1-x_0\mid x_t] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] - \mathbb{E}_{x_0,x_1}[x_0\mid x_t] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] - \mathbb{E}_{x_0,x_1}\left[\frac{x_t-t x_1}{1-t}\,\Big|\, x_t\right] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] - \mathbb{E}_{x_0,x_1}\left[\frac{x_t}{1-t}\,\Big|\, x_t\right] + \mathbb{E}_{x_0,x_1}\left[\frac{t x_1}{1-t}\,\Big|\, x_t\right] \\ &= \mathbb{E}_{x_0,x_1}[x_1\mid x_t] -\frac{x_t}{1-t}+ \frac{t}{1-t}\,\mathbb{E}_{x_0,x_1}\left[x_1\mid x_t\right] \\ &= -\frac{x_t}{1-t} + \frac{1}{1-t}\,\mathbb{E}_{x_0,x_1}\left[x_1\mid x_t\right] \\ &= -\frac{x_t}{1-t} + \frac{1}{1-t}\cdot \big(-\beta_t\nabla \log p_t(x_t)\big) \\ &= -\frac{x_t}{1-t} - \frac{t}{1-t}\, \nabla \log p_t(x_t) \end{aligned}$$
Isolating $\nabla \log p_t(x_t)$ on the left-hand side gives
$$\nabla\log p_t(x_t) = -\frac{x_t}{t}-\frac{1-t}{t}v_\theta(x_t,t)$$
Substituting this into Eq. (11) yields the final expression:
$$\begin{aligned} dx_t &= \left[v_t(x_t,t)-\frac{\sigma_t^2}{2}\left( -\frac{x_t}{t}-\frac{1-t}{t}v_\theta(x_t,t) \right)\right]dt + \sigma_t\,d\bar w\\ dx_t &= \left[v_t(x_t,t)+\frac{\sigma_t^2}{2t}\left( x_t+(1-t)v_\theta(x_t,t) \right)\right]dt + \sigma_t\,d\bar w \end{aligned} \tag{15}$$
This completes the proof.


8 Closing

That wraps up this post. If anything looks off, please point it out. Arigato!

