当前位置: 首页 > news >正文

流匹配(Flow Matching)的生成过程:求解反向常微分方程(ODE)

流匹配(Flow Matching):从理论到实践

引言

近年来,扩散模型(Diffusion Models)成为生成建模领域的热门技术,它们依靠逐步加噪和去噪的方式生成高质量数据。然而,扩散模型在采样时需要数百甚至上千步去噪,计算开销较大。相比之下,流匹配(Flow Matching) 提供了一种新的视角,它直接学习数据的演化流(vector field),可以通过求解常微分方程(ODE)高效生成数据,减少推理时间。

在本篇博客中,我们将详细介绍流匹配的基本原理 (更深入的理解请移步笔者的另一篇博客流匹配(Flow Matching)教程),并结合 PyTorch 代码实现流匹配的训练和生成过程,帮助你掌握这一技术。


1. 流匹配的基本原理

流匹配是一种基于常微分方程(ODE)的生成方法,其核心思想是学习一个连续时间演化的向量场 ( v ( x , t ) \mathbf{v}(x, t) v(x,t) ),使数据从噪声流动到目标分布。数学上,这个演化过程可以描述为:

d x ( t ) d t = v ( x ( t ) , t ) \frac{d\mathbf{x}(t)}{dt} = \mathbf{v}(\mathbf{x}(t), t) dtdx(t)=v(x(t),t)

其中:

  • ( x ( t ) \mathbf{x}(t) x(t) ) 表示数据在时间 ( t t t ) 的状态;
  • ( v ( x , t ) \mathbf{v}(x, t) v(x,t) ) 是向量场,描述了数据随时间的变化趋势;
  • 该方程的初始条件通常设定为噪声,最终状态 ( x ( 0 ) \mathbf{x}(0) x(0) ) 应该符合目标数据分布。

1.1 训练:学习向量场

流匹配的训练目标是让模型学习到正确的向量场,使得在任意时间点 ( t t t ),数据都沿着正确的轨迹演化。具体来说,我们通过构造一个参考概率路径(Reference Probability Path) ( x t x_t xt ),然后训练模型的输出 ( v θ ( x , t ) \mathbf{v}_\theta(x, t) vθ(x,t) ) 逼近真实速度 ( u ( x t , t ) \mathbf{u}(x_t, t) u(xt,t) )。

训练损失函数为均方误差(MSE):
L ( θ ) = E x 1 , t [ ∥ v θ ( x t , t ) − u ( x t , t ) ∥ 2 ] \mathcal{L}(\theta) = \mathbb{E}_{x_1, t} \left[ \|\mathbf{v}_\theta(x_t, t) - \mathbf{u}(x_t, t)\|^2 \right] L(θ)=Ex1,t[vθ(xt,t)u(xt,t)2]
其中:

  • ( x 1 x_1 x1 ) 来自目标数据分布;
  • ( x t x_t xt ) 是根据已知路径生成的中间状态;
  • ( u ( x t , t ) \mathbf{u}(x_t, t) u(xt,t) ) 是真实的速度。

2. 代码实现:训练流匹配模型

下面的代码实现了一个二维向量场网络,用于训练流匹配模型,使其能够学习数据从噪声流动到目标分布的路径。
具体原理和代码注释请参考笔者的另一篇博客:流匹配(Flow Matching)教程,这里主要在于看一下它是如何反向求解ODE的。

2.1 定义向量场网络

我们使用一个简单的多层感知机(MLP)来拟合 ( v ( x , t ) \mathbf{v}(x, t) v(x,t) ):

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(42)

# 定义二维向量场网络
class VectorField(nn.Module):
    def __init__(self):
        super(VectorField, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2 + 1, 128),  # 输入:x1, x2, t
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 2)  # 输出:u1, u2
        )
    
    def forward(self, t, x):
        t_input = t.unsqueeze(-1).expand(x.size(0), 1)
        inp = torch.cat([t_input, x], dim=1)
        return self.net(inp)

2.2 采样路径:构造条件概率路径

为了训练流匹配模型,我们需要一个参考路径,即从目标数据分布 ( q ( x 1 ) q(x_1) q(x1) ) 到噪声的路径。这里我们使用线性插值路径
x t = t x 1 + ( 1 − t ) ϵ , ϵ ∼ N ( 0 , I ) x_t = t x_1 + (1-t) \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) xt=tx1+(1t)ϵ,ϵN(0,I)

并计算对应的真实速度场:

# 生成条件路径 x_t 和真实速度 u_true
def sample_conditional_path(x1, t, sigma_min=0.0):
    t = t.view(-1, 1)
    mu_t = t * x1
    sigma_t = (1 - t) + t * sigma_min
    eps = torch.randn_like(x1)
    x_t = mu_t + sigma_t * eps

    dot_mu_t = x1
    dot_sigma_t = -1.0 + sigma_min
    term1 = (dot_sigma_t / sigma_t) * (x_t - mu_t)
    term2 = dot_mu_t
    u_true = term1 + term2
    return x_t, u_true

2.3 训练循环

训练目标是让模型输出 ( v θ ( x , t ) \mathbf{v}_\theta(x, t) vθ(x,t) ) 尽可能逼近真实速度 ( u ( x t , t ) u(x_t, t) u(xt,t) ):

# 训练
device = torch.device('cpu')
model = VectorField().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 目标数据分布(两个高斯团)
def sample_target_data(batch_size):
    comp = torch.randint(0, 2, (batch_size, 1))
    centers = torch.tensor([[-4.0, 0.0], [4.0, 0.0]])
    center = centers[comp.view(-1)]
    return center + torch.randn(batch_size, 2)

# 训练循环
for step in range(10000):
    model.train()
    batch_size = 128
    t = torch.rand(batch_size, device=device)
    x1 = sample_target_data(batch_size).to(device)
    x_t, u_true = sample_conditional_path(x1, t)

    u_pred = model(t, x_t)
    loss = ((u_pred - u_true) ** 2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 1000 == 0:
        print(f"Step {step}, loss={loss.item():.4f}")

3. 生成过程:求解反向 ODE

训练完成后,我们需要从噪声 ( x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I) ) 生成数据 ( x 0 x_0 x0 )。由于 ODE 具有时间可逆性,我们可以使用 ODE 求解器(如 Runge-Kutta 方法)反向求解:

from torchdiffeq import odeint

# 反向ODE函数
class ReverseODEFunc(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, t, x):
        return -self.model(t, x)

# 生成样本
def generate_samples(model, batch_size=1024, timesteps=100):
    t_eval = torch.linspace(1, 0, timesteps)  # 反向时间步
    xT = torch.randn(batch_size, 2)  # 从高斯噪声开始
    ode_func = ReverseODEFunc(model)
    x0 = odeint(ode_func, xT, t_eval, method='rk4')[-1]  # 反向求解ODE
    return x0.detach().cpu()

# 生成数据
samples = generate_samples(model)

4. 总结

流匹配学习向量场,从而能够通过ODE求解器高效生成数据
相比扩散模型,流匹配可以减少采样步数,提高生成速度
本文代码实现了流匹配的训练和生成,完整复现了该方法

流匹配是生成建模的新方向,随着研究的深入,它有望成为扩散模型之外的另一种强大工具!

后记

2025年2月26日20点39分于上海,在GPT 4o大模型辅助下完成。

相关文章:

  • 单例模式——c++
  • JavaScript将:;隔开的字符串转换为json格式。使用正则表达式匹配键值对,并构建对象。多用于解析cssText为style Object对象
  • 基础知识|原型在什么时候用和类的区别
  • 机试刷题_HJ14 字符串排序【python】
  • CSS盒子模型
  • 算法每日一练 (6)
  • Python 类(创建和使用类)
  • 自然语言处理:初识自然语言处理
  • SQL基本知识
  • 代码随想录二刷|动态规划11
  • 最新版本SpringAI接入DeepSeek大模型,并集成Mybatis
  • Linux系统里怎么怎么截图
  • 低代码与开发框架的一些整合[3]
  • 超大规模分类(四):Partial FC
  • ReentrantLock 底层实现
  • 【git】【reset全解】Git 回到上次提交并处理提交内容的不同方式
  • AI智能体与大语言模型:重塑SaaS系统的未来航向
  • HTML篇
  • 区块链仿真工具SimBlock使用
  • PDF处理控件Aspose.PDF教程:使用 Python 将 PDF 转换为 TIFF
  • 金融监管总局:近五年民企贷款投放年平均增速比各项贷款平均增速高出1.1个百分点
  • 特色茶酒、非遗挂面……六安皋品入沪赴“五五购物节”
  • 特朗普要征100%关税,好莱坞这批境外摄制新片能躲过吗?
  • 抗战回望18︱《广西学生军》:“广西的政治基础是青年”
  • 农村青年寻路纪|劳动者的书信⑤
  • 美国加州州长:加州继续对中国“敞开贸易大门”