【代码】TorchCFM(Conditional Flow Matching library)代码入门
- 代码结构速览
- torchcfm/
- conditional_flow_matching.py
- Class ConditionalFlowMatcher
- class ExactOptimalTransportConditionalFlowMatcher
- class TargetConditionalFlowMatcher
- class SchrodingerBridgeConditionalFlowMatcher
- optimal_transport.py
- runner/
- src/train.py
- configs/**.yaml
- src/models/cfm_module.py
- src/models/runner.py
paper:Improving and generalizing flow-based generative models with minibatch optimal transport
📦 GitHub 仓库: https://github.com/atong01/conditional-flow-matching

代码结构速览
conditional-flow-matching/
├── torchcfm/
│ ├── models/ # 模型定义(神经网络架构)
│ │ └── unet # 主干网络
│ ├── conditional_flow_matching.py # Conditional Flow Matching 损失实现
│ ├── optimal_transport.py # 采样
│ └── utils.py/ # 辅助函数
├── examples/
│ ├── 2D_tutorials # 二维
│ ├── images # 图像生成示例
│ └── ...
├── runner/
└── README.md
torchcfm/
conditional_flow_matching.py
Class ConditionalFlowMatcher
在 CFM(I-CFM)的基本形式中,将 zzz 定义为一对随机变量,源点 x0x_0x0 和目标点 x1x_1x1,并设置 q(z)=q(x0)q(x1)q(z) = q(x_0)q(x_1)q(z)=q(x0)q(x1) 为独立耦合。让条件为 x0x_0x0 和 x1x_1x1 之间的高斯流,标准差为 σ\sigmaσ,定义为
pt(x∣z)=N(x∣tx1+(1−t)x0,σ2)(14)p_t(x|z) =\mathcal N \left(x | tx_1 + (1 − t)x_0, \sigma^2 \right) \tag{14}pt(x∣z)=N(x∣tx1+(1−t)x0,σ2)(14)
ut(x∣z)=x1−x0(15)u_t(x|z) = x_1 − x_0 \tag{15}ut(x∣z)=x1−x0(15)
ut(x∣z)u_t(x|z)ut(x∣z) 的公式源自将定理 2.1 应用于条件概率路径,其中 μt=tx1+(1−t)x0\mu_t = tx_1 + (1 − t)x_0μt=tx1+(1−t)x0 且 σt=σ\sigma_t = \sigmaσt=σ。
class ConditionalFlowMatcher:"Base class for conditional flow matching methods."def __init__(self, sigma: Union[float, int] = 0.0):self.sigma = sigmadef compute_mu_t(self, x0, x1, t):"Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), x0 represents the source minibatch, x1 represents the target minibatch, see (Eq.14) [1]."t = pad_t_like_x(t, x0)return t * x1 + (1 - t) * x0def compute_sigma_t(self, t):"Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]."del t # 基类直接忽略 t, 返回常数 self.sigmareturn self.sigmadef sample_xt(self, x0, x1, t, epsilon):"Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]."mu_t = self.compute_mu_t(x0, x1, t)sigma_t = self.compute_sigma_t(t)sigma_t = pad_t_like_x(sigma_t, x0)return mu_t + sigma_t * epsilondef compute_conditional_flow(self, x0, x1, t, xt):"ut : conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]."del t, xt # 基类实现不使用 t xt, 该变体是假定向量场在任何 t 任何中间点都是恒定的return x1 - x0def sample_noise_like(self, x):return torch.randn_like(x)def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):"Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]."if t is None:t = torch.rand(x0.shape[0]).type_as(x0)assert len(t) == x0.shape[0], "t has to have batch size dimension"eps = self.sample_noise_like(x0)xt = self.sample_xt(x0, x1, t, eps)ut = self.compute_conditional_flow(x0, x1, t, xt)if return_noise:return t, xt, ut, epselse:return t, xt, utdef compute_lambda(self, t):"Compute the lambda: score weighting function, see Eq.(23) [3]"sigma_t = self.compute_sigma_t(t)return 2 * sigma_t / (self.sigma**2 + 1e-8)
采样 sample_xt() 函数对应论文中的公式 (14),代码使用了 重参数化技巧(Reparameterization Trick)。由于是从高斯分布采样:x∼N(μt,σ2I)x \sim \mathcal{N}(\mu_t, \sigma^2 I)x∼N(μt,σ2I),可以直接用重参数化方式表达为:
x=μt+σ⋅ε,ε∼N(0,I)x = \mu_t + \sigma \cdot \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I) x=μt+σ⋅ε,ε∼N(0,I)
这是深度生成模型中的标准做法(VAE、扩散模型等都使用),优点是:采样过程可导(允许梯度回传),易于 GPU 并行化。所以代码 return mu_t + sigma_t * epsilon。
class ExactOptimalTransportConditionalFlowMatcher
class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):"Child class for optimal transport conditional flow matching method."def __init__(self, sigma: Union[float, int] = 0.0):super().__init__(sigma)self.ot_sampler = OTPlanSampler(method="exact")def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):"Compute the sample xt and the conditional vector field ut(x1|x0) with respect to the minibatch OT plan $\Pi$."x0, x1 = self.ot_sampler.sample_plan(x0, x1) # 先用 OT sampler 对 (x0,x1) 重新配对return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)def guided_sample_location_and_conditional_flow(self, x0, x1, y0=None, y1=None, t=None, return_noise=False):"y0 represents the source label minibatch, y1 represents the target label minibatch."x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)if return_noise:t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1, epselse:t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1
覆写 sample_location_and_conditional_flow:先用 OT sampler 对 (x0,x1)(x_0,x_1)(x0,x1) 重新配对,再调用父类的 sample_location_and_conditional_flow。这样意味着生成的 (x0,x1)(x_0,x_1)(x0,x1) 对是按照 minibatch OT 进行 最优配对(不是简单随机配对),通常能提高训练效果。
class TargetConditionalFlowMatcher
Example II: Optimal Transport conditional VFs. 条件概率路径的一个可以说更自然的选择是将均值和标准差定义为随时间线性变化,即 μt(x)=tx1,and σt(x)=1−(1−σmin)t(20)\mu_t(x) = tx_1 , \ \text{and} \ \sigma_t(x) = 1 − (1 − \sigma_{\text{min}})t \tag{20}μt(x)=tx1, and σt(x)=1−(1−σmin)t(20)根据定理 3,该路径由 VF
ut(x∣x1)=x1−(1−σmin)x1−(1−σmin)t(21)u_t(x|x_1) = \frac{x_1 − (1 − \sigma_{\text{min}}) x}{ 1 − (1 − \sigma_{\text{min}})t} \tag{21}ut(x∣x1)=1−(1−σmin)tx1−(1−σmin)x(21) 生成。
class TargetConditionalFlowMatcher(ConditionalFlowMatcher):"2023 Lipman et al. style target OT conditional flow matching."def compute_mu_t(self, x0, x1, t):"Compute the mean of the probability path tx1, see (Eq.20) [2]"del x0t = pad_t_like_x(t, x1)return t * x1def compute_sigma_t(self, t):"Compute the standard deviation of the probability path N(t x1, 1 - (1 - sigma) t), see (Eq.20) [2]."return 1 - (1 - self.sigma) * tdef compute_conditional_flow(self, x0, x1, t, xt):"Compute the conditional vector field ut(x|x1) = (x1 - (1 - sigma) x)/(1 - (1 - sigma) t ), see Eq.(21) [2]. xt represents the samples drawn from probability path pt."del x0t = pad_t_like_x(t, x1)return (x1 - (1 - self.sigma) * xt) / (1 - (1 - self.sigma) * t)
class SchrodingerBridgeConditionalFlowMatcher
我们将条件路径分布设定为一个在 x0x_0x0 和 x1x_1x1 之间、扩散尺度为 σ\sigmaσ 的布朗桥(Brownian Bridge)。该路径的概率分布与生成向量场定义如下:
pt(x∣z)=N(x∣tx1+(1−t)x0,t(1−t)σ2)(20)p_t(x \mid z) = \mathcal{N}\big(x \,\big|\, t x_1 + (1 - t)x_0,\; t(1 - t)\sigma^2 \big) \tag{20} pt(x∣z)=N(xtx1+(1−t)x0,t(1−t)σ2)(20)
ut(x∣z)=1−2t2t(1−t)(x−(tx1+(1−t)x0))+(x1−x0),(21)u_t(x \mid z) = \frac{1 - 2t}{2t(1 - t)} \big(x - (t x_1 + (1 - t)x_0)\big) + (x_1 - x_0), \tag{21} ut(x∣z)=2t(1−t)1−2t(x−(tx1+(1−t)x0))+(x1−x0),(21)
其中,utu_tut 由公式 (5) 计算,作为生成概率路径 pt(x∣z)p_t(x \mid z)pt(x∣z) 的向量场。边缘耦合(marginal coupling)π2σ2\pi_{2\sigma^2}π2σ2 与 ut(x∣z)u_t(x \mid z)ut(x∣z) 一起定义了 ut(x)u_t(x)ut(x),后者通过算法 4 中的回归目标进行近似。
class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):"Child class for Schrödinger bridge conditional flow matching method."def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):if sigma <= 0:raise ValueError(f"Sigma must be strictly positive, got {sigma}.")elif sigma < 1e-3:warnings.warn("Small sigma values may lead to numerical instability.")super().__init__(sigma)self.ot_method = ot_methodself.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)def compute_sigma_t(self, t):"Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sqrt(t * (1 - t))*sigma^2), see (Eq.20) [1]."return self.sigma * torch.sqrt(t * (1 - t))def compute_conditional_flow(self, x0, x1, t, xt):"Compute the conditional vector field. ut(x1|x0) = (1 - 2 * t) / (2 * t * (1 - t)) * (xt - mu_t) + x1 - x0, see Eq.(21) [1]."t = pad_t_like_x(t, x0)mu_t = self.compute_mu_t(x0, x1, t)sigma_t_prime_over_sigma_t = (1 - 2 * t) / (2 * t * (1 - t) + 1e-8)ut = sigma_t_prime_over_sigma_t * (xt - mu_t) + x1 - x0return utdef sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):"Compute the sample xt and the conditional vector field ut(x1|x0) with respect to the minibatch entropic OT plan."x0, x1 = self.ot_sampler.sample_plan(x0, x1)return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)def guided_sample_location_and_conditional_flow(self, x0, x1, y0=None, y1=None, t=None, return_noise=False):x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)if return_noise:t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1, epselse:t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1
optimal_transport.py
class OTPlanSampler:"OTPlanSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost) with different implementations of the plan calculation."def __init__(self,method: str,reg: float = 0.05,reg_m: float = 1.0,normalize_cost: bool = False,num_threads: Union[int, str] = 1,warn: bool = True,) -> None:if method == "exact":self.ot_fn = partial(pot.emd, numThreads=num_threads)elif method == "sinkhorn":self.ot_fn = partial(pot.sinkhorn, reg=reg)elif method == "unbalanced":self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m)elif method == "partial":self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)else:raise ValueError(f"Unknown method: {method}")self.reg = regself.reg_m = reg_mself.normalize_cost = normalize_costself.warn = warndef get_map(self, x0, x1):"Compute the OT plan (wrt squared Euclidean cost) between a source and a target minibatch."
method:指定使用哪种 OT 求解器,代码支持的值为 exact、sinkhorn、unbalanced、partial。
-
exact:精确地求解 Earth Mover’s Distance(通常基于线性规划 / Hungarian / network flow)。适合较小 batch,精度高但慢。 -
sinkhorn:基于熵正则化的 Sinkhorn 算法(可扩展到较大规模,需设置reg)。 -
unbalanced:用于不守恒质量(mass)情形的 unbalanced Sinkhorn(有额外的reg_m)。 -
partial:部分 OT(partial OT)或带熵的部分 Wasserstein(常见于部分匹配场景)。
快速示例:如何创建实例
# 精确 OT,适合小 batch,可能比较慢
sampler = OTPlanSampler(method="exact", num_threads="max")# Sinkhorn,适合较大 batch,可调 reg
sampler = OTPlanSampler(method="sinkhorn", reg=0.1)# Schrödinger Bridge 场景常用 entropic OT:reg 应与 sigma 关联
sampler = OTPlanSampler(method="sinkhorn", reg=2 * sigma**2)
runner/
src/train.py
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:"Trains the model. Can additionally evaluate on a testset, using best weights obtained during training."datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)model: LightningModule = hydra.utils.instantiate(cfg.model)(datamodule=datamodule)callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)if cfg.get("train"):log.info("Starting training!")trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))train_metrics = trainer.callback_metricsif cfg.get("test"):log.info("Starting testing!")ckpt_path = trainer.checkpoint_callback.best_model_pathtrainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)test_metrics = trainer.callback_metrics# merge train and test metricsmetric_dict = {**train_metrics, **test_metrics} return metric_dict, object_dict
train.py 是一个 通用的训练入口。它本身不固定训练哪一个模型——它根据 configs(Hydra)来实例化:
- 一个
LightningDataModule(数据加载器/数据预处理), - 以及一个
LightningModule(模型 +loss+optimizer+training_step), - 然后用
pytorch_lightning.Trainer去fit(训练)和可选test(评估)。
所以要知道训练什么,需要看 configs 指向了哪个 LightningModule。在这个 repo(Flow Matching)里,常见的是训练 Conditional Flow Matching / Schrödinger Bridge / flow model 的 LightningModule:也就是训练一个网络去拟合速度场,来做生成或建模时间序列分布。
configs/**.yaml
Hydra 项目通常把行为都放在 configs/**.yaml。打开项目的 configs/train.yaml,在 train.yaml 里能看到 datamodule: ..., model: ..., trainer: ... 等。打开对应源码文件查看真正的训练逻辑(loss、forward)——真正训练的“是什么”通常在 LightningModule 里。
# runner/configs/train.yaml
defaults: # 顺序很重要:后面的会覆盖前面的同名字段(键冲突时以最后出现的为准)- _self_- datamodule: sklearn- model: cfm- callbacks: default- trainer: default- paths: default
datamodule: sklearn:默认将加载runner/configs/datamodule/sklearn.yaml(决定数据加载、预处理、batch size、train_val_test_split等)。model: cfm:默认使用runner/configs/model/cfm.yaml(指定要训练的模型类型、网络net,optimizer、partial_solver等)。
# runner/configs/model/cfm.yaml
_target_: src.models.cfm_module.CFMLitModuleoptimizer:_target_: torch.optim.AdamWlr: 0.001weight_decay: 1e-5net:_target_: src.models.components.simple_mlp.VelocityNethidden_dims: [64, 64, 64]batch_norm: Falseactivation: "selu"partial_solver:_target_: src.models.components.solver.FlowSolverode_solver: "euler"atol: 1e-5rtol: 1e-5
callbacks: default,trainer: default,paths: default等分别指定 callbacks、trainer、路径等配置文件。
src/models/cfm_module.py
runner/configs/model/cfm.yaml 中的 _target_: src.models.cfm_module.CFMLitModule 这一行的作用是当调用 model: LightningModule = hydra.utils.instantiate(cfg.model) 时,Hydra 会自动导入并实例化类 src.models.cfm_module.CFMLitModule。
也就是说,Hydra 会找到文件:
conditional-flow-matching/
└── src/└── models/└── cfm_module.py
并在其中找到类定义:
class CFMLitModule(LightningModule):"Conditional Flow Matching Module for training generative models and models over time."def __init__(self,net: Any,optimizer: Any,datamodule: LightningDataModule,augmentations: AugmentationModule,partial_solver: FlowSolver,scheduler: Optional[Any] = None,neural_ode: Optional[Any] = None,ot_sampler: Optional[Union[str, Any]] = None,sigma_min: float = 0.1,avg_size: int = -1,leaveout_timepoint: int = -1,test_nfe: int = 100,plot: bool = False,nice_name: str = "CFM",) -> None:super().__init__()self.net = net(dim=self.dim)self.partial_solver = partial_solverself.optimizer = optimizerself.scheduler = schedulerself.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)self.criterion = torch.nn.MSELoss()...
CFMLitModule ≈ 一个 Lightning 封装的 Conditional Flow Matching 模型
其中包含:
- 一个神经网络(通常是 UNet 或 MLP,用于预测速度场 uθu_\thetauθ);
- 一个损失函数(实现了 Flow Matching loss,用于逼近理论流场);
- 一个采样/积分器(例如 ODE/SDE solver,生成样本)。
| 属性 | 类型 | 作用 |
|---|---|---|
self.net | torch module | 预测速度场 dx/dt=f(t,x)dx/dt = f(t,x)dx/dt=f(t,x) |
self.aug_net | AugmentedVectorField | 对网络输出做增广,增强稳定性 |
self.datamodule | LightningDataModule | 提供数据、维度、轨迹信息 |
self.is_trajectory | bool | 数据是否为时间序列轨迹 |
self.dim | int/tuple | 数据维度(向量或图像) |
self.partial_solver | FlowSolver | 用于积分ODE/生成轨迹 |
self.ot_sampler | OTPlanSampler / None | 用于轨迹匹配和最优传输 |
self.criterion | torch.nn.MSELoss | 速度场回归损失 |
self.scheduler | optional | 学习率调度器 |
self.hparams.* | various | 超参数(sigma_min, avg_size, leaveout_timepoint 等) |
class CFMLitModule(LightningModule):def forward(self, t: torch.Tensor, x: torch.Tensor):"""Forward pass (t, x) -> dx/dt."""return self.net(t, x)def forward_integrate(self, batch: Any, t_span: torch.Tensor):"Forward pass with integration over t_span intervals. (t, x, t_span) -> [x_t_span]."X = self.unpack_batch(batch)X_start = X[:, t_span[0], :]traj = self.node.trajectory(X_start, t_span=t_span)return trajdef calc_u(self, x0, x1, x, t, mu_t, sigma_t):del x, t, mu_t, sigma_treturn x1 - x0def calc_loc_and_target(self, x0, x1, t, t_select, training):"""Computes the loss on a batch of data."""t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)eps_t = torch.randn_like(mu_t)x = mu_t + sigma_t * eps_tut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)...def step(self, batch: Any, training: bool = False):"""Computes the loss on a batch of data."""X = self.unpack_batch(batch)x0, x1, t_select = self.preprocess_batch(X, training)# Resample the plan if we are using optimal transportif self.ot_sampler is not None and not self.is_trajectory:x0, x1 = self.ot_sampler.sample_plan(x0, x1)x, ut, t, mu_t, sigma_t, eps_t = self.calc_loc_and_target(x0, x1, t, t_select, training)if self.hparams.avg_size > 0:x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)aug_x = self.aug_net(t, x, augmented_input=False)reg, vt = self.augmentations(aug_x)return torch.mean(reg), self.criterion(vt, ut)def training_step(self, batch: Any, batch_idx: int):reg, mse = self.step(batch, training=True)loss = mse + regprefix = "train"self.log_dict({f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg},on_step=True,on_epoch=False,prog_bar=True,)return loss
-
forward()作用:给定时间 ttt 和状态 xxx,返回速度场 dx/dtdx/dtdx/dt。本质上就是 神经网络 f(t,x)f(t, x)f(t,x) 的前向传播。- 在 Conditional Flow Matching (CFM) 的公式里,这个网络就是学习目标向量场
ut(x∣x1)u_t(x|x_1)ut(x∣x1) 的近似:net(t,x)≈ut(x∣x1)\text{net}(t,x)≈u_t(x|x_1)net(t,x)≈ut(x∣x1); - 训练的时候,
step()会计算 ut=x1−x0u_t = x_1 - x_0ut=x1−x0,然后把网络输出和ut对比做 MSE loss。
所以
forward只是 网络本身的前向传播函数。 - 在 Conditional Flow Matching (CFM) 的公式里,这个网络就是学习目标向量场
-
forward_integrate()作用:根据初始状态X_start,沿着网络学习的速度场进行积分,生成整个时间轨迹。X_start = X[:, t_span[0], :]:取batch的初始状态(起点)。self.node.trajectory(X_start, t_span=t_span):使用神经ODE/ODE求解器沿t_span积分 dx/dt=f(t,x)dx/dt=f(t,x)dx/dt=f(t,x)。
返回完整轨迹
traj:shape 大约是(batch, len(t_span), dim)。在 CFM 或者 Neural ODE 中,这就是 生成数据/预测时间序列 的关键步骤:
- 训练阶段:用它生成预测轨迹对比真实轨迹计算 loss。
- 测试阶段:用它生成完整样本轨迹。
| 函数 | 功能 |
|---|---|
forward(t, x) | 网络前向传播,预测瞬时速度 dx/dtdx/dtdx/dt,用于训练 loss |
forward_integrate(batch, t_span) | 从起点沿网络预测速度场积分,生成完整轨迹,用于生成或测试 |
self.calc_loc_and_target() 里调用了 self.calc_u() 来得到 ut,然后用这个 ut 作为 训练目标。换句话说,训练目标 ut 就是 Conditional Flow Matching 的向量场。
conditional_flow_matching.py:算法核心,生成目标速度场。
CFMLitModule:训练框架,管理数据、优化器、Lightning接口、日志、评估。训练时
CFMLitModule可以用conditional_flow_matching.py的函数来计算训练目标。
但这里 CFMLitModule 实现了 Conditional Flow Matching 的核心思想,把公式内嵌在了 self.calc_u() 里,没有依赖 conditional_flow_matching.py 中的 compute_conditional_flow(x0, x1, t, x)。
src/models/runner.py
这里 runner.py 里的 CFMLitModule 直接调用 ConditionalFlowMatcher 来计算训练目标:
class CFMLitModule(LightningModule):def __init__(self,net: Any,optimizer: Any,datamodule: LightningDataModule,flow_matcher: ConditionalFlowMatcher,solver: FlowSolver,scheduler: Optional[Any] = None,plot: bool = False,) -> None:super().__init__()self.net = net(dim=self.dim)self.solver = solverself.optimizer = optimizerself.flow_matcher = flow_matcherself.scheduler = schedulerself.criterion = torch.nn.MSELoss()...def step(self, batch: Any, training: bool = False):"""Computes the loss on a batch of data."""x0, x1 = self.preprocess_batch(batch, training)t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x1)vt = self.net(t, xt)return torch.nn.functional.mse_loss(vt, ut)def training_step(self, batch: Any, batch_idx: int):loss = self.step(batch, training=True)self.log("train/loss", loss, on_step=True, prog_bar=True)return loss
这里 self.step() 中:
x0, x1是从 batch 里取出的数据(或噪声生成的初始点)。ut是 条件向量场 (conditional flow),也就是 网络需要学习的目标。vt = net(t, xt)是 网络的预测。loss = mse(vt, ut)是 训练目标。
所以 runner.py 的 CFMLitModule 是 训练器/封装器,而 ConditionalFlowMatcher 是 计算训练目标的工具。
| 特点 | conditional_flow_matching.py | runner.py |
|---|---|---|
| 前向网络 | net(t, x) | 同样 net(t, x) |
| 训练目标 | 通过 calc_u、average_ut 等手动计算 | 调用 flow_matcher.sample_location_and_conditional_flow 直接得到 |
| 积分/生成轨迹 | forward_integrate | forward_eval_integrate |
| 功能侧重点 | 全面,支持 trajectory / image / OT / leaveout_time | 简化版,直接和 ConditionalFlowMatcher 对接,更专注于训练流程 |
