流匹配(Flow Matching)的生成过程:求解反向常微分方程(ODE)
流匹配(Flow Matching):从理论到实践
引言
近年来,扩散模型(Diffusion Models)成为生成建模领域的热门技术,它们依靠逐步加噪和去噪的方式生成高质量数据。然而,扩散模型在采样时需要数百甚至上千步去噪,计算开销较大。相比之下,流匹配(Flow Matching) 提供了一种新的视角,它直接学习数据的演化流(vector field),可以通过求解常微分方程(ODE)高效生成数据,减少推理时间。
在本篇博客中,我们将详细介绍流匹配的基本原理 (更深入的理解请移步笔者的另一篇博客流匹配(Flow Matching)教程),并结合 PyTorch 代码实现流匹配的训练和生成过程,帮助你掌握这一技术。
1. 流匹配的基本原理
流匹配是一种基于常微分方程(ODE)的生成方法,其核心思想是学习一个连续时间演化的向量场 ( v ( x , t ) \mathbf{v}(x, t) v(x,t) ),使数据从噪声流动到目标分布。数学上,这个演化过程可以描述为:
d x ( t ) d t = v ( x ( t ) , t ) \frac{d\mathbf{x}(t)}{dt} = \mathbf{v}(\mathbf{x}(t), t) dtdx(t)=v(x(t),t)
其中:
- ( x ( t ) \mathbf{x}(t) x(t) ) 表示数据在时间 ( t t t ) 的状态;
- ( v ( x , t ) \mathbf{v}(x, t) v(x,t) ) 是向量场,描述了数据随时间的变化趋势;
- 该方程的初始条件通常设定为噪声,最终状态 ( x ( 0 ) \mathbf{x}(0) x(0) ) 应该符合目标数据分布。
1.1 训练:学习向量场
流匹配的训练目标是让模型学习到正确的向量场,使得在任意时间点 ( t t t ),数据都沿着正确的轨迹演化。具体来说,我们通过构造一个参考概率路径(Reference Probability Path) ( x t x_t xt ),然后训练模型的输出 ( v θ ( x , t ) \mathbf{v}_\theta(x, t) vθ(x,t) ) 逼近真实速度 ( u ( x t , t ) \mathbf{u}(x_t, t) u(xt,t) )。
训练损失函数为均方误差(MSE):
L
(
θ
)
=
E
x
1
,
t
[
∥
v
θ
(
x
t
,
t
)
−
u
(
x
t
,
t
)
∥
2
]
\mathcal{L}(\theta) = \mathbb{E}_{x_1, t} \left[ \|\mathbf{v}_\theta(x_t, t) - \mathbf{u}(x_t, t)\|^2 \right]
L(θ)=Ex1,t[∥vθ(xt,t)−u(xt,t)∥2]
其中:
- ( x 1 x_1 x1 ) 来自目标数据分布;
- ( x t x_t xt ) 是根据已知路径生成的中间状态;
- ( u ( x t , t ) \mathbf{u}(x_t, t) u(xt,t) ) 是真实的速度。
2. 代码实现:训练流匹配模型
下面的代码实现了一个二维向量场网络,用于训练流匹配模型,使其能够学习数据从噪声流动到目标分布的路径。
具体原理和代码注释请参考笔者的另一篇博客:流匹配(Flow Matching)教程,这里主要在于看一下它是如何反向求解ODE的。
2.1 定义向量场网络
我们使用一个简单的多层感知机(MLP)来拟合 ( v ( x , t ) \mathbf{v}(x, t) v(x,t) ):
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(42)
# 定义二维向量场网络
class VectorField(nn.Module):
def __init__(self):
super(VectorField, self).__init__()
self.net = nn.Sequential(
nn.Linear(2 + 1, 128), # 输入:x1, x2, t
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 2) # 输出:u1, u2
)
def forward(self, t, x):
t_input = t.unsqueeze(-1).expand(x.size(0), 1)
inp = torch.cat([t_input, x], dim=1)
return self.net(inp)
2.2 采样路径:构造条件概率路径
为了训练流匹配模型,我们需要一个参考路径,即从目标数据分布 (
q
(
x
1
)
q(x_1)
q(x1) ) 到噪声的路径。这里我们使用线性插值路径:
x
t
=
t
x
1
+
(
1
−
t
)
ϵ
,
ϵ
∼
N
(
0
,
I
)
x_t = t x_1 + (1-t) \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
xt=tx1+(1−t)ϵ,ϵ∼N(0,I)
并计算对应的真实速度场:
# 生成条件路径 x_t 和真实速度 u_true
def sample_conditional_path(x1, t, sigma_min=0.0):
t = t.view(-1, 1)
mu_t = t * x1
sigma_t = (1 - t) + t * sigma_min
eps = torch.randn_like(x1)
x_t = mu_t + sigma_t * eps
dot_mu_t = x1
dot_sigma_t = -1.0 + sigma_min
term1 = (dot_sigma_t / sigma_t) * (x_t - mu_t)
term2 = dot_mu_t
u_true = term1 + term2
return x_t, u_true
2.3 训练循环
训练目标是让模型输出 ( v θ ( x , t ) \mathbf{v}_\theta(x, t) vθ(x,t) ) 尽可能逼近真实速度 ( u ( x t , t ) u(x_t, t) u(xt,t) ):
# 训练
device = torch.device('cpu')
model = VectorField().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# 目标数据分布(两个高斯团)
def sample_target_data(batch_size):
comp = torch.randint(0, 2, (batch_size, 1))
centers = torch.tensor([[-4.0, 0.0], [4.0, 0.0]])
center = centers[comp.view(-1)]
return center + torch.randn(batch_size, 2)
# 训练循环
for step in range(10000):
model.train()
batch_size = 128
t = torch.rand(batch_size, device=device)
x1 = sample_target_data(batch_size).to(device)
x_t, u_true = sample_conditional_path(x1, t)
u_pred = model(t, x_t)
loss = ((u_pred - u_true) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 1000 == 0:
print(f"Step {step}, loss={loss.item():.4f}")
3. 生成过程:求解反向 ODE
训练完成后,我们需要从噪声 ( x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xT∼N(0,I) ) 生成数据 ( x 0 x_0 x0 )。由于 ODE 具有时间可逆性,我们可以使用 ODE 求解器(如 Runge-Kutta 方法)反向求解:
from torchdiffeq import odeint
# 反向ODE函数
class ReverseODEFunc(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, t, x):
return -self.model(t, x)
# 生成样本
def generate_samples(model, batch_size=1024, timesteps=100):
t_eval = torch.linspace(1, 0, timesteps) # 反向时间步
xT = torch.randn(batch_size, 2) # 从高斯噪声开始
ode_func = ReverseODEFunc(model)
x0 = odeint(ode_func, xT, t_eval, method='rk4')[-1] # 反向求解ODE
return x0.detach().cpu()
# 生成数据
samples = generate_samples(model)
4. 总结
✅ 流匹配学习向量场,从而能够通过ODE求解器高效生成数据
✅ 相比扩散模型,流匹配可以减少采样步数,提高生成速度
✅ 本文代码实现了流匹配的训练和生成,完整复现了该方法
流匹配是生成建模的新方向,随着研究的深入,它有望成为扩散模型之外的另一种强大工具!
后记
2025年2月26日20点39分于上海,在GPT 4o大模型辅助下完成。