当前位置: 首页 > news >正文

【代码】TorchCFM(Conditional Flow Matching library)代码入门

  • 代码结构速览
    • torchcfm/
      • conditional_flow_matching.py
        • Class ConditionalFlowMatcher
        • class ExactOptimalTransportConditionalFlowMatcher
        • class TargetConditionalFlowMatcher
        • class SchrodingerBridgeConditionalFlowMatcher
      • optimal_transport.py
    • runner/
      • src/train.py
      • configs/**.yaml
      • src/models/cfm_module.py
      • src/models/runner.py

paper:Improving and generalizing flow-based generative models with minibatch optimal transport
📦 GitHub 仓库: https://github.com/atong01/conditional-flow-matching

在这里插入图片描述

代码结构速览

conditional-flow-matching/
├── torchcfm/
│   ├── models/          	# 模型定义(神经网络架构)
│   │   └── unet     		# 主干网络
│   ├── conditional_flow_matching.py   	# Conditional Flow Matching 损失实现
│   ├── optimal_transport.py          	#  采样 
│   └── utils.py/           			# 辅助函数
├── examples/
│   ├── 2D_tutorials    # 二维
│   ├── images     		# 图像生成示例
│   └── ...
├── runner/
└── README.md

torchcfm/

conditional_flow_matching.py

Class ConditionalFlowMatcher

在 CFM(I-CFM)的基本形式中,将 zzz 定义为一对随机变量,源点 x0x_0x0 和目标点 x1x_1x1,并设置 q(z)=q(x0)q(x1)q(z) = q(x_0)q(x_1)q(z)=q(x0)q(x1) 为独立耦合。让条件为 x0x_0x0x1x_1x1 之间的高斯流,标准差为 σ\sigmaσ,定义为
pt(x∣z)=N(x∣tx1+(1−t)x0,σ2)(14)p_t(x|z) =\mathcal N \left(x | tx_1 + (1 − t)x_0, \sigma^2 \right) \tag{14}pt(xz)=N(xtx1+(1t)x0,σ2)(14)
ut(x∣z)=x1−x0(15)u_t(x|z) = x_1 − x_0 \tag{15}ut(xz)=x1x0(15)
ut(x∣z)u_t(x|z)ut(xz) 的公式源自将定理 2.1 应用于条件概率路径,其中 μt=tx1+(1−t)x0\mu_t = tx_1 + (1 − t)x_0μt=tx1+(1t)x0σt=σ\sigma_t = \sigmaσt=σ

class ConditionalFlowMatcher:"Base class for conditional flow matching methods."def __init__(self, sigma: Union[float, int] = 0.0):self.sigma = sigmadef compute_mu_t(self, x0, x1, t):"Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), x0 represents the source minibatch, x1 represents the target minibatch, see (Eq.14) [1]."t = pad_t_like_x(t, x0)return t * x1 + (1 - t) * x0def compute_sigma_t(self, t):"Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]."del t  # 基类直接忽略 t, 返回常数 self.sigmareturn self.sigmadef sample_xt(self, x0, x1, t, epsilon):"Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1]."mu_t = self.compute_mu_t(x0, x1, t)sigma_t = self.compute_sigma_t(t)sigma_t = pad_t_like_x(sigma_t, x0)return mu_t + sigma_t * epsilondef compute_conditional_flow(self, x0, x1, t, xt):"ut : conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]."del t, xt  # 基类实现不使用 t xt, 该变体是假定向量场在任何 t 任何中间点都是恒定的return x1 - x0def sample_noise_like(self, x):return torch.randn_like(x)def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):"Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]."if t is None:t = torch.rand(x0.shape[0]).type_as(x0)assert len(t) == x0.shape[0], "t has to have batch size dimension"eps = self.sample_noise_like(x0)xt = self.sample_xt(x0, x1, t, eps)ut = self.compute_conditional_flow(x0, x1, t, xt)if return_noise:return t, xt, ut, epselse:return t, xt, utdef compute_lambda(self, t):"Compute the lambda: score weighting function, see Eq.(23) [3]"sigma_t = self.compute_sigma_t(t)return 2 * sigma_t / (self.sigma**2 + 1e-8)

采样 sample_xt() 函数对应论文中的公式 (14),代码使用了 重参数化技巧(Reparameterization Trick)。由于是从高斯分布采样:x∼N(μt,σ2I)x \sim \mathcal{N}(\mu_t, \sigma^2 I)xN(μt,σ2I),可以直接用重参数化方式表达为:
x=μt+σ⋅ε,ε∼N(0,I)x = \mu_t + \sigma \cdot \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I) x=μt+σε,εN(0,I)

这是深度生成模型中的标准做法(VAE、扩散模型等都使用),优点是:采样过程可导(允许梯度回传),易于 GPU 并行化。所以代码 return mu_t + sigma_t * epsilon

class ExactOptimalTransportConditionalFlowMatcher
class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):"Child class for optimal transport conditional flow matching method."def __init__(self, sigma: Union[float, int] = 0.0):super().__init__(sigma)self.ot_sampler = OTPlanSampler(method="exact")def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):"Compute the sample xt and the conditional vector field ut(x1|x0) with respect to the minibatch OT plan $\Pi$."x0, x1 = self.ot_sampler.sample_plan(x0, x1)  # 先用 OT sampler 对 (x0,x1) 重新配对return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)def guided_sample_location_and_conditional_flow(self, x0, x1, y0=None, y1=None, t=None, return_noise=False):"y0 represents the source label minibatch, y1 represents the target label minibatch."x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)if return_noise:t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1, epselse:t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1

覆写 sample_location_and_conditional_flow:先用 OT sampler 对 (x0,x1)(x_0,x_1)(x0,x1) 重新配对,再调用父类的 sample_location_and_conditional_flow。这样意味着生成的 (x0,x1)(x_0,x_1)(x0,x1) 对是按照 minibatch OT 进行 最优配对(不是简单随机配对),通常能提高训练效果。

class TargetConditionalFlowMatcher

Example II: Optimal Transport conditional VFs. 条件概率路径的一个可以说更自然的选择是将均值和标准差定义为随时间线性变化,即 μt(x)=tx1,and σt(x)=1−(1−σmin)t(20)\mu_t(x) = tx_1 , \ \text{and} \ \sigma_t(x) = 1 − (1 − \sigma_{\text{min}})t \tag{20}μt(x)=tx1, and σt(x)=1(1σmin)t(20)根据定理 3,该路径由 VF
ut(x∣x1)=x1−(1−σmin)x1−(1−σmin)t(21)u_t(x|x_1) = \frac{x_1 − (1 − \sigma_{\text{min}}) x}{ 1 − (1 − \sigma_{\text{min}})t} \tag{21}ut(xx1)=1(1σmin)tx1(1σmin)x(21) 生成。

class TargetConditionalFlowMatcher(ConditionalFlowMatcher):"2023 Lipman et al. style target OT conditional flow matching."def compute_mu_t(self, x0, x1, t):"Compute the mean of the probability path tx1, see (Eq.20) [2]"del x0t = pad_t_like_x(t, x1)return t * x1def compute_sigma_t(self, t):"Compute the standard deviation of the probability path N(t x1, 1 - (1 - sigma) t), see (Eq.20) [2]."return 1 - (1 - self.sigma) * tdef compute_conditional_flow(self, x0, x1, t, xt):"Compute the conditional vector field ut(x|x1) = (x1 - (1 - sigma) x)/(1 - (1 - sigma) t ), see Eq.(21) [2]. xt represents the samples drawn from probability path pt."del x0t = pad_t_like_x(t, x1)return (x1 - (1 - self.sigma) * xt) / (1 - (1 - self.sigma) * t)
class SchrodingerBridgeConditionalFlowMatcher

我们将条件路径分布设定为一个在 x0x_0x0x1x_1x1 之间、扩散尺度为 σ\sigmaσ布朗桥(Brownian Bridge)。该路径的概率分布与生成向量场定义如下:
pt(x∣z)=N(x∣tx1+(1−t)x0,t(1−t)σ2)(20)p_t(x \mid z) = \mathcal{N}\big(x \,\big|\, t x_1 + (1 - t)x_0,\; t(1 - t)\sigma^2 \big) \tag{20} pt(xz)=N(xtx1+(1t)x0,t(1t)σ2)(20)
ut(x∣z)=1−2t2t(1−t)(x−(tx1+(1−t)x0))+(x1−x0),(21)u_t(x \mid z) = \frac{1 - 2t}{2t(1 - t)} \big(x - (t x_1 + (1 - t)x_0)\big) + (x_1 - x_0), \tag{21} ut(xz)=2t(1t)12t(x(tx1+(1t)x0))+(x1x0),(21)
其中,utu_tut 由公式 (5) 计算,作为生成概率路径 pt(x∣z)p_t(x \mid z)pt(xz) 的向量场。边缘耦合(marginal coupling)π2σ2\pi_{2\sigma^2}π2σ2ut(x∣z)u_t(x \mid z)ut(xz) 一起定义了 ut(x)u_t(x)ut(x),后者通过算法 4 中的回归目标进行近似。

class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):"Child class for Schrödinger bridge conditional flow matching method."def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):if sigma <= 0:raise ValueError(f"Sigma must be strictly positive, got {sigma}.")elif sigma < 1e-3:warnings.warn("Small sigma values may lead to numerical instability.")super().__init__(sigma)self.ot_method = ot_methodself.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)def compute_sigma_t(self, t):"Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sqrt(t * (1 - t))*sigma^2), see (Eq.20) [1]."return self.sigma * torch.sqrt(t * (1 - t))def compute_conditional_flow(self, x0, x1, t, xt):"Compute the conditional vector field. ut(x1|x0) = (1 - 2 * t) / (2 * t * (1 - t)) * (xt - mu_t) + x1 - x0, see Eq.(21) [1]."t = pad_t_like_x(t, x0)mu_t = self.compute_mu_t(x0, x1, t)sigma_t_prime_over_sigma_t = (1 - 2 * t) / (2 * t * (1 - t) + 1e-8)ut = sigma_t_prime_over_sigma_t * (xt - mu_t) + x1 - x0return utdef sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):"Compute the sample xt and the conditional vector field ut(x1|x0) with respect to the minibatch entropic OT plan."x0, x1 = self.ot_sampler.sample_plan(x0, x1)return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)def guided_sample_location_and_conditional_flow(self, x0, x1, y0=None, y1=None, t=None, return_noise=False):x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)if return_noise:t, xt, ut, eps = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1, epselse:t, xt, ut = super().sample_location_and_conditional_flow(x0, x1, t, return_noise)return t, xt, ut, y0, y1

optimal_transport.py

class OTPlanSampler:"OTPlanSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost) with different implementations of the plan calculation."def __init__(self,method: str,reg: float = 0.05,reg_m: float = 1.0,normalize_cost: bool = False,num_threads: Union[int, str] = 1,warn: bool = True,) -> None:if method == "exact":self.ot_fn = partial(pot.emd, numThreads=num_threads)elif method == "sinkhorn":self.ot_fn = partial(pot.sinkhorn, reg=reg)elif method == "unbalanced":self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m)elif method == "partial":self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)else:raise ValueError(f"Unknown method: {method}")self.reg = regself.reg_m = reg_mself.normalize_cost = normalize_costself.warn = warndef get_map(self, x0, x1):"Compute the OT plan (wrt squared Euclidean cost) between a source and a target minibatch."

method:指定使用哪种 OT 求解器,代码支持的值为 exactsinkhornunbalancedpartial

  • exact:精确地求解 Earth Mover’s Distance(通常基于线性规划 / Hungarian / network flow)。适合较小 batch,精度高但慢。

  • sinkhorn:基于熵正则化的 Sinkhorn 算法(可扩展到较大规模,需设置 reg)。

  • unbalanced:用于不守恒质量(mass)情形的 unbalanced Sinkhorn(有额外的 reg_m)。

  • partial:部分 OT(partial OT)或带熵的部分 Wasserstein(常见于部分匹配场景)。

快速示例:如何创建实例

# 精确 OT,适合小 batch,可能比较慢
sampler = OTPlanSampler(method="exact", num_threads="max")# Sinkhorn,适合较大 batch,可调 reg
sampler = OTPlanSampler(method="sinkhorn", reg=0.1)# Schrödinger Bridge 场景常用 entropic OT:reg 应与 sigma 关联
sampler = OTPlanSampler(method="sinkhorn", reg=2 * sigma**2)

runner/

src/train.py

@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:"Trains the model. Can additionally evaluate on a testset, using best weights obtained during training."datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)model: LightningModule = hydra.utils.instantiate(cfg.model)(datamodule=datamodule)callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)if cfg.get("train"):log.info("Starting training!")trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))train_metrics = trainer.callback_metricsif cfg.get("test"):log.info("Starting testing!")ckpt_path = trainer.checkpoint_callback.best_model_pathtrainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)test_metrics = trainer.callback_metrics# merge train and test metricsmetric_dict = {**train_metrics, **test_metrics}	return metric_dict, object_dict

train.py 是一个 通用的训练入口。它本身不固定训练哪一个模型——它根据 configs(Hydra)来实例化:

  • 一个 LightningDataModule(数据加载器/数据预处理),
  • 以及一个 LightningModule(模型 + loss + optimizer + training_step),
  • 然后用 pytorch_lightning.Trainerfit(训练)和可选 test(评估)。

所以要知道训练什么,需要看 configs 指向了哪个 LightningModule。在这个 repo(Flow Matching)里,常见的是训练 Conditional Flow Matching / Schrödinger Bridge / flow modelLightningModule:也就是训练一个网络去拟合速度场,来做生成或建模时间序列分布。

configs/**.yaml

Hydra 项目通常把行为都放在 configs/**.yaml。打开项目的 configs/train.yaml,在 train.yaml 里能看到 datamodule: ..., model: ..., trainer: ... 等。打开对应源码文件查看真正的训练逻辑(lossforward)——真正训练的“是什么”通常在 LightningModule 里。

# runner/configs/train.yaml
defaults:  # 顺序很重要:后面的会覆盖前面的同名字段(键冲突时以最后出现的为准)- _self_- datamodule: sklearn- model: cfm- callbacks: default- trainer: default- paths: default
  • datamodule: sklearn:默认将加载 runner/configs/datamodule/sklearn.yaml(决定数据加载、预处理、batch sizetrain_val_test_split 等)。
  • model: cfm:默认使用 runner/configs/model/cfm.yaml(指定要训练的模型类型、网络 netoptimizerpartial_solver 等)。
# runner/configs/model/cfm.yaml
_target_: src.models.cfm_module.CFMLitModuleoptimizer:_target_: torch.optim.AdamWlr: 0.001weight_decay: 1e-5net:_target_: src.models.components.simple_mlp.VelocityNethidden_dims: [64, 64, 64]batch_norm: Falseactivation: "selu"partial_solver:_target_: src.models.components.solver.FlowSolverode_solver: "euler"atol: 1e-5rtol: 1e-5
  • callbacks: default, trainer: default, paths: default 等分别指定 callbacks、trainer、路径等配置文件。

src/models/cfm_module.py

runner/configs/model/cfm.yaml 中的 _target_: src.models.cfm_module.CFMLitModule 这一行的作用是当调用 model: LightningModule = hydra.utils.instantiate(cfg.model) 时,Hydra 会自动导入并实例化类 src.models.cfm_module.CFMLitModule

也就是说,Hydra 会找到文件:

conditional-flow-matching/
└── src/└── models/└── cfm_module.py

并在其中找到类定义:

class CFMLitModule(LightningModule):"Conditional Flow Matching Module for training generative models and models over time."def __init__(self,net: Any,optimizer: Any,datamodule: LightningDataModule,augmentations: AugmentationModule,partial_solver: FlowSolver,scheduler: Optional[Any] = None,neural_ode: Optional[Any] = None,ot_sampler: Optional[Union[str, Any]] = None,sigma_min: float = 0.1,avg_size: int = -1,leaveout_timepoint: int = -1,test_nfe: int = 100,plot: bool = False,nice_name: str = "CFM",) -> None:super().__init__()self.net = net(dim=self.dim)self.partial_solver = partial_solverself.optimizer = optimizerself.scheduler = schedulerself.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)self.criterion = torch.nn.MSELoss()...

CFMLitModule ≈ 一个 Lightning 封装的 Conditional Flow Matching 模型

其中包含:

  • 一个神经网络(通常是 UNet 或 MLP,用于预测速度场 uθu_\thetauθ);
  • 一个损失函数(实现了 Flow Matching loss,用于逼近理论流场);
  • 一个采样/积分器(例如 ODE/SDE solver,生成样本)。
属性类型作用
self.nettorch module预测速度场 dx/dt=f(t,x)dx/dt = f(t,x)dx/dt=f(t,x)
self.aug_netAugmentedVectorField对网络输出做增广,增强稳定性
self.datamoduleLightningDataModule提供数据、维度、轨迹信息
self.is_trajectorybool数据是否为时间序列轨迹
self.dimint/tuple数据维度(向量或图像)
self.partial_solverFlowSolver用于积分ODE/生成轨迹
self.ot_samplerOTPlanSampler / None用于轨迹匹配和最优传输
self.criteriontorch.nn.MSELoss速度场回归损失
self.scheduleroptional学习率调度器
self.hparams.*various超参数(sigma_min, avg_size, leaveout_timepoint 等)
class CFMLitModule(LightningModule):def forward(self, t: torch.Tensor, x: torch.Tensor):"""Forward pass (t, x) -> dx/dt."""return self.net(t, x)def forward_integrate(self, batch: Any, t_span: torch.Tensor):"Forward pass with integration over t_span intervals. (t, x, t_span) -> [x_t_span]."X = self.unpack_batch(batch)X_start = X[:, t_span[0], :]traj = self.node.trajectory(X_start, t_span=t_span)return trajdef calc_u(self, x0, x1, x, t, mu_t, sigma_t):del x, t, mu_t, sigma_treturn x1 - x0def calc_loc_and_target(self, x0, x1, t, t_select, training):"""Computes the loss on a batch of data."""t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)eps_t = torch.randn_like(mu_t)x = mu_t + sigma_t * eps_tut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)...def step(self, batch: Any, training: bool = False):"""Computes the loss on a batch of data."""X = self.unpack_batch(batch)x0, x1, t_select = self.preprocess_batch(X, training)# Resample the plan if we are using optimal transportif self.ot_sampler is not None and not self.is_trajectory:x0, x1 = self.ot_sampler.sample_plan(x0, x1)x, ut, t, mu_t, sigma_t, eps_t = self.calc_loc_and_target(x0, x1, t, t_select, training)if self.hparams.avg_size > 0:x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)aug_x = self.aug_net(t, x, augmented_input=False)reg, vt = self.augmentations(aug_x)return torch.mean(reg), self.criterion(vt, ut)def training_step(self, batch: Any, batch_idx: int):reg, mse = self.step(batch, training=True)loss = mse + regprefix = "train"self.log_dict({f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg},on_step=True,on_epoch=False,prog_bar=True,)return loss
  • forward() 作用:给定时间 ttt 和状态 xxx,返回速度场 dx/dtdx/dtdx/dt。本质上就是 神经网络 f(t,x)f(t, x)f(t,x) 的前向传播。

    • 在 Conditional Flow Matching (CFM) 的公式里,这个网络就是学习目标向量场
      ut(x∣x1)u_t(x|x_1)ut(xx1) 的近似
      net(t,x)≈ut(x∣x1)\text{net}(t,x)≈u_t(x|x_1)net(t,x)ut(xx1)
    • 训练的时候,step() 会计算 ut=x1−x0u_t = x_1 - x_0ut=x1x0,然后把网络输出和 ut 对比做 MSE loss

    所以 forward 只是 网络本身的前向传播函数

  • forward_integrate() 作用:根据初始状态 X_start,沿着网络学习的速度场进行积分,生成整个时间轨迹

    • X_start = X[:, t_span[0], :]:取 batch 的初始状态(起点)。
    • self.node.trajectory(X_start, t_span=t_span):使用神经ODE/ODE求解器沿 t_span 积分 dx/dt=f(t,x)dx/dt=f(t,x)dx/dt=f(t,x)

    返回完整轨迹 traj:shape 大约是 (batch, len(t_span), dim)

    在 CFM 或者 Neural ODE 中,这就是 生成数据/预测时间序列 的关键步骤:

    • 训练阶段:用它生成预测轨迹对比真实轨迹计算 loss。
    • 测试阶段:用它生成完整样本轨迹。
函数功能
forward(t, x)网络前向传播,预测瞬时速度 dx/dtdx/dtdx/dt,用于训练 loss
forward_integrate(batch, t_span)从起点沿网络预测速度场积分,生成完整轨迹,用于生成或测试

self.calc_loc_and_target() 里调用了 self.calc_u() 来得到 ut,然后用这个 ut 作为 训练目标。换句话说,训练目标 ut 就是 Conditional Flow Matching 的向量场。

conditional_flow_matching.py算法核心,生成目标速度场
CFMLitModule训练框架,管理数据、优化器、Lightning接口、日志、评估。

训练时 CFMLitModule 可以用 conditional_flow_matching.py 的函数来计算训练目标。

但这里 CFMLitModule 实现了 Conditional Flow Matching 的核心思想,把公式内嵌在了 self.calc_u() 里,没有依赖 conditional_flow_matching.py 中的 compute_conditional_flow(x0, x1, t, x)

src/models/runner.py

这里 runner.py 里的 CFMLitModule 直接调用 ConditionalFlowMatcher 来计算训练目标:

class CFMLitModule(LightningModule):def __init__(self,net: Any,optimizer: Any,datamodule: LightningDataModule,flow_matcher: ConditionalFlowMatcher,solver: FlowSolver,scheduler: Optional[Any] = None,plot: bool = False,) -> None:super().__init__()self.net = net(dim=self.dim)self.solver = solverself.optimizer = optimizerself.flow_matcher = flow_matcherself.scheduler = schedulerself.criterion = torch.nn.MSELoss()...def step(self, batch: Any, training: bool = False):"""Computes the loss on a batch of data."""x0, x1 = self.preprocess_batch(batch, training)t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x1)vt = self.net(t, xt)return torch.nn.functional.mse_loss(vt, ut)def training_step(self, batch: Any, batch_idx: int):loss = self.step(batch, training=True)self.log("train/loss", loss, on_step=True, prog_bar=True)return loss

这里 self.step() 中:

  • x0, x1 是从 batch 里取出的数据(或噪声生成的初始点)。
  • ut条件向量场 (conditional flow),也就是 网络需要学习的目标
  • vt = net(t, xt)网络的预测
  • loss = mse(vt, ut)训练目标

所以 runner.pyCFMLitModule训练器/封装器,而 ConditionalFlowMatcher计算训练目标的工具。

特点conditional_flow_matching.pyrunner.py
前向网络net(t, x)同样 net(t, x)
训练目标通过 calc_uaverage_ut 等手动计算调用 flow_matcher.sample_location_and_conditional_flow 直接得到
积分/生成轨迹forward_integrateforward_eval_integrate
功能侧重点全面,支持 trajectory / image / OT / leaveout_time简化版,直接和 ConditionalFlowMatcher 对接,更专注于训练流程
http://www.dtcms.com/a/600846.html

相关文章:

  • C++主流日志库深度剖析:从原理到选型的全维度指南
  • CAD/CASS 无法复制到剪贴板
  • C语言在线编译环境 | 轻松学习C语言编程,随时随地在线编程
  • C语言在线编译器开发 | 提供高效、易用的在线编程平台
  • 东莞专业做网站的公司有哪些安徽建设工程信息网技术服务电话
  • 【前端面试】Git篇
  • Oracle RAC 再遇 MTU 坑:cssd 无法启动!
  • 用asp做网站怎么布局t型布局网站的优缺点
  • OpenGL lookAt 函数 参数说明
  • 【刷题笔记】 AOV网的拓扑排序
  • 3D TOF 视觉相机:以毫秒级三维感知,开启智能交互新时代
  • 快速配置 HBase 完全分布式(依赖已部署的 Hadoop+ZooKeeper)
  • 深圳网站搜索排名产品软文范例软文
  • 手机网站关键词seo网站 模板 html
  • 多模态工程师面试--准备
  • 安全迁移Windows个人文件夹至非C盘:分步教程与避坑指南
  • 多智能体框架AgentScope 1.0 深度技术剖析:架构、场景、选型与实战指南
  • flinkcdc抽取postgres数据
  • SpringCloud Gateway缓存body参数引发的问题
  • Qt跨平台:Linux与Windows
  • 【数据集分享】汽车价格预测数据集
  • 汽车网络安全综合参考架构
  • 亚远景-ISO 26262与ISO 21434:未来汽车安全标准的发展趋势
  • Leverege 携手谷歌云和BigQuery,赋能大规模企业级物联网(IoT)解决方案
  • 国外网站服务器免费网站被做跳转
  • 分享一个我自用的 Python 消息发送模块,支持邮件、钉钉、企业微信
  • 南昌商城网站建设网页设计作业文件
  • 物联网传感器数据漂移自适应补偿与精度动态校正技术
  • docker 按带ssh的python环境的容器
  • 基于深度随机森林(Deep Forest)的分类算法实现