阅读论文 smart pretrain,搭配MAE一起食用
安心定志
论文信息:
论文名称: Smart pretrain: model-agnostic and data-agostic representation learning for motion prediction
发表日期和会议:ICLR 2025
作者单位:商汤科技,港中文,多伦多大学,InnoHK CPII,上海AI lab
摘要:
预测周围智能体的未来运动对于自动驾驶车辆(AV)在动态、人机混合环境中安全运行至关重要。然而,大规模驾驶数据集的稀缺性阻碍了鲁棒且通用的运动预测模型的发展,限制了其捕捉复杂交互和道路几何结构的能力。受自然语言处理(NLP)和计算机视觉(CV)领域最新进展的启发,自监督学习(SSL)在运动预测社区中引起了广泛关注,用于学习丰富且可迁移的场景表示。尽管如此,现有的运动预测预训练方法大多专注于特定模型架构和单一数据集,限制了其可扩展性和通用性。为了解决这些挑战,我们提出了SmartPretrain,一种通用且可扩展的SSL框架,用于运动预测,该框架与模型和数据集无关。我们的方法结合了对比学习和重建式SSL,利用生成式和判别式范式的优势,有效表示时空演化和交互,而无需施加架构约束。此外,SmartPretrain采用了一种与数据集无关的场景采样策略,整合多个数据集,从而增强数据量、多样性和鲁棒性。在多个数据集上的广泛实验表明,SmartPretrain在不同数据集、数据划分和主要指标上始终提升了最先进的预测模型性能。例如,SmartPretrain显著将Forecast-MAE的MissRate降低了10.6%。这些结果突显了SmartPretrain作为统一、可扩展的运动预测解决方案的有效性,突破了小数据量机制的限制。代码可在https://github.com/youngzhou1999/SmartPretrain获取。
引言:
运动预测,即预测附近共享空间的智能体(例如车辆、骑行者、行人)的未来状态,对于自动驾驶系统在动态且人机混合环境中安全高效运行至关重要。上下文信息,包括周围智能体的状态和高精地图(HD地图),为运动行为提供了关键的几何和语义信息,因为智能体的行为高度依赖于与周围智能体及地图拓扑的交互。例如,智能体之间的互动线索(如让行)会影响其他智能体的决策,而车辆通常会在可行驶区域移动并遵循车道方向。因此,设计和学习能够捕捉丰富运动和上下文信息的场景表示一直是运动预测的核心挑战。
尽管自然语言处理(NLP)和计算机视觉(CV)领域已经多次展示了大规模数据集的强大力量,但由于收集和标注驾驶轨迹数据的成本高昂且耗时费力,现有的运动预测工作仍处于“小数据量模式”。例如,流行的运动预测数据集如Argoverse、Argoverse 2和Waymo Open Motion Dataset(WOMD)分别仅包含32万、25万和48万个数据序列,这远少于NLP和CV领域常用的数以亿计的数据规模(比如,在4000亿token的CommonCrawl数据集上训练的ChatGPT-3,以及在3.03亿张图像的JFT数据集上训练的Vision Transformer。轨迹数据的稀缺性限制了模型学习丰富且可迁移的场景表示的能力,从而制约了它们的表现和通用性。
为此,一些研究尝试通过利用基于先验知识和人工设计规则生成的合成轨迹数据来缓解真实驾驶运动数据的稀缺问题。然而,使用合成数据的一个主要缺点是“现实差距”——即人工生成的数据分布通常与现实世界的数据分布不同,导致合成训练与实际表现之间存在差距。此外,生成高保真的模拟需要大量的计算资源和精心设计,以便准确捕捉复杂的智能体交互,这也限制了其扩展性。
除了扩大数据规模外,NLP和CV领域在自监督学习(SSL)方面也取得了显著进展。BERT和Masked Autoencoders等模型的成功证明了通过在未标记数据上进行预训练可以获得富有表达力的表示/特征,从而在微调后增强下游任务的表现。因此,自监督学习最近在运动预测社区中获得了越来越多的关注。然而,设计真正通用且可扩展的SSL预训练策略对于运动预测而言并非易事。通过将现有工作归类为两大主要方法:生成式SSL和判别式SSL,我们可以详细探讨这些挑战。
在生成式自监督学习(SSL)领域,预训练任务旨在从真实数据中学习富含上下文信息的场景表示。为此,提出了多种与交通相关的预训练任务,例如操作行为分类或预测成功性、地图图关系预测以及交通事件检测。此外,为了复制在NLP和CV领域中Masked Autoencoder(MAE)技术的成功,研究者探索了多种针对交通场景的掩码策略。然而,这些预训练策略通常对模型架构施加了限制——它们首先提出预训练方法,而实现这些方法则需要模型生成特定类型的特征。例如,操作行为检测、图关系预测和交通事件检测要求模型显式地表达这些概念/特征。MAE方法要求每个轨迹或地图片段必须具有显式的特征表示,以支持重建式预训练。在仅对轨迹进行掩码和重建的方法中,尽管方法简单且灵活,但本质上仅充当一种数据增强技术。尽管许多工作专注于聚合智能体嵌入并提供对其的显式访问,但显式地图嵌入并不总是可用。因此,MAE中的地图重建预训练策略可能缺乏通用性。由于这种灵活性不足,这些SSL策略通常只能适用于特定模型,或者在大多数情况下仅适用于其单一特定模型。
在判别式自监督学习(SSL)领域,对比学习已成为运动预测任务中一项有前景的技术,其方法旨在通过对正负样本之间的轨迹和地图嵌入进行对比来学习判别性特征。然而,这种方法局限于基于栅格化地图表示的模型,而研究表明,与更近期基于Transformer或图的模型(这些模型引入了矢量化地图)相比,栅格化方法存在显著的性能差距。将对比学习应用于矢量化表示仍然是一个尚未充分探索的领域,因为矢量数据的不规则结构使得定义有意义的正负样本变得复杂,并需要复杂的采样策略。
总结来说,由于运动预测任务中输入表示的固有复杂性和多模态特性,现有的SSL预训练策略在适用性上存在不足,难以适应通用的模型架构和输入表示。此外,不同流行的运动预测数据集之间的差异进一步阻碍了这些方法利用多个数据集的能力,从而无法摆脱小数据量模式的限制。如图1所示,这些局限性削弱了自监督学习的真正潜力,其能够为不同模型架构、输入表示和数据来源提供统一且可扩展的表示学习策略。
在本文中,我们提出了 SmartPretrain,一种通用且可扩展的 SSL 表示学习方法,用于运动预测。为了充分释放 SSL 的潜力,SmartPretrain 专门设计为与模型无关且与数据集无关,从而提供了一个适用于不同模型架构和多样化数据来源的统一解决方案,摆脱了小数据量模式下次优表示学习的限制。我们的核心设计包括以下两个方面:
-
与模型无关的对比学习和重建式 SSL:SmartPretrain 提出了一种新颖的与模型无关的 SSL 框架,结合了生成式 SSL 和判别式 SSL 的优势。简而言之,我们同时重建轨迹并对来自不同智能体和时间窗口的轨迹嵌入进行对比,以学习场景中的时空演化和交互行为。这些专注于轨迹的 SSL 预训练任务避免了对模型架构和地图表示的限制,使 SmartPretrain 能够适应广泛的模型设计,无论是基于栅格、Transformer 还是图的模型。
-
与数据集无关的场景采样策略:我们提出了一种与数据集无关的场景采样策略,通过整合多个数据集实现有效扩展,尽管这些数据集之间存在固有差异。为此,我们标准化了数据表示,确保高质量的输入,并最大化数据量和多样性。这使得 SmartPretrain 能够利用更广泛的驾驶场景,增强泛化能力和鲁棒性。
通过对多个最先进的预测模型和多个数据集应用 SmartPretrain 的广泛实验表明,SmartPretrain 在下游数据集、数据划分和主要指标上一致提升了所有被评估模型的性能。例如,SmartPretrain 显著降低了 QCNet 的 minFDE、minADE 和 MR 分别达 4.9%、3.3% 和 7.6%,并降低了 Forecast-MAE 的对应指标达 4.5%、3.1% 和 10.6%。相较于现有的预训练方法,SmartPretrain 展现了卓越的性能。我们还进行了全面的消融研究,分析了流水线中每个组件的作用,例如多数据集扩展的效果以及两种提出的预训练任务的影响。据我们所知,SmartPretrain 是首个利用多个数据集进行 SSL 预训练的工作,可以应用于多种模型,用于驾驶领域的运动预测。
相关工作:
轨迹预测
传统的运动预测方法主要使用卡尔曼滤波,结合高精地图(HD-map)中的物理和操作先验来预测未来的运动状态,或者基于采样或优化的规划算法,通过手动指定或学习奖励函数生成未来轨迹。随着深度学习的快速发展,近期的研究开始采用数据驱动的方法进行运动预测。总体而言,这些方法可以分为三种不同的架构:
-
基于栅格图像和卷积神经网络(CNN)的方法:
栅格化方法利用CNN将场景上下文转换为鸟瞰图(Cui等,2019;Chai等,2019b)。 -
基于矢量化表示和图神经网络(GNN)的方法:
这些方法将场景中的每个实体表示为矢量,遵循Gao等(2020)提出的矢量化表示方式。 -
基于Transformer的方法:
Transformer架构在处理序列数据时表现出色,被广泛应用于运动预测任务中。
为了表示场景信息,基于栅格的方法将场景上下文栅格化为鸟瞰图,而其他两种架构则将场景中的每个实体表示为矢量。对于多模态轨迹输出,除了标准的单阶段预测流程(直接输出轨迹)外,目前还出现了两阶段预测优化流程。在两阶段流程中,首先生成粗略轨迹,然后对其进行细化以提高预测精度。
传统轨迹预测的参考文献:
[1]
轨迹预测中的自监督学习
自监督学习(SSL)方法已在视觉理解、自然语言处理以及多模态表示学习中得到了广泛应用。SSL旨在通过精心设计的预训练任务来学习信息丰富且通用的表示,这些表示可以通过监督微调用于下游任务。
最近,轨迹预测领域在整合SSL技术方面取得了显著进展。这些方法利用预训练任务对模型进行预训练,使其能够学习有价值的表示,并通过微调提升轨迹预测性能。现有的轨迹预测预训练流程可以分为三类:基于增强或合成数据的方法、对比学习方法。
-
基于合成数据的方法:
这些方法利用先验知识和人工设计规则生成轨迹和地图数据。例如,Li等(2024)通过地图增强模块和基于模型的规划模型生成驾驶场景。 -
对比学习方法:
对比学习方法通过对正负样本进行比较,对嵌入进行对齐和区分,从而捕捉高层次的语义关系。例如,PreTraM 是一种基于栅格的方法,它使用对比学习建模轨迹与地图之间以及地图之间的关系。 -
生成式掩码表示学习方法:
这些方法借鉴了Transformer和掩码自编码器(MAE)的成功经验,通过重建随机掩码的高精地图上下文和轨迹来学习标记间的关系。例如,Forecast-MAE将智能体的历史轨迹和车道段视为独立标记,在标记级别应用随机掩码,并将掩码后的标记输入Transformer主干网络进行重建。
尽管如此,现有的SSL流程存在两个主要局限性:
- 单数据集限制:由于不同数据集的格式差异,现有方法通常仅在一个轨迹预测数据集上进行预训练。
- 特定架构依赖:它们依赖于特定的模型架构和地图嵌入,限制了其对通用模型的适应性。此外,这些流程难以扩展到先进的基于图神经网络(GNN)的方法,因为这些方法将所有输入整合为统一的图结构,无法为预训练任务提供显式的地图嵌入。
为了解决这些挑战,我们提出了 SmartPretrain,一种与模型无关且与数据集无关的解决方案。它可以灵活应用于各种模型和数据集,不受模型架构或数据集格式的限制。
方法
问题定义——运动预测与自监督学习的结合
经典的轨迹预测任务可以定义如下:
在运动预测任务中,我们的目标是基于目标智能体在过去
T
h
T_h
Th 步的状态观测
s
h
=
[
s
−
T
h
+
1
,
s
−
T
h
+
2
,
.
.
.
,
s
0
]
∈
R
T
h
×
2
s_h = [s_{-T_h+1}, s_{-T_h+2}, ..., s_0] \in \mathbb{R}^{T_h \times 2}
sh=[s−Th+1,s−Th+2,...,s0]∈RTh×2,来预测其未来
T
f
T_f
Tf 步的状态
s
f
=
[
s
1
,
s
2
,
.
.
.
,
s
T
f
]
∈
R
T
f
×
2
s_f = [s_1, s_2, ..., s_{T_f}] \in \mathbb{R}^{T_f \times 2}
sf=[s1,s2,...,sTf]∈RTf×2 以及相关的概率
p
p
p。这里,每个状态通常包括位置(x, y坐标)等信息。自然地,目标智能体会与其上下文
c
c
c 发生交互,这些上下文包括周围智能体的观测状态及高精地图(HD Map)。典型的运动预测任务可以形式化为
(
s
f
,
p
)
=
f
(
s
h
,
c
)
(s_f, p) = f(s_h, c)
(sf,p)=f(sh,c),其中
f
f
f) 表示预测模型。
一般而言,运动预测模型采用编码器-解码器架构,并分为两个阶段:
-
编码阶段 z = f e n c ( s h , c ) z = f_{enc}(s_h, c) z=fenc(sh,c):
在这一阶段,编码器 f e n c f_{enc} fenc 将 s h s_h sh 和 c c c 嵌入并融合,以捕捉交通场景中的演变和交互,并生成轨迹嵌入 z z z。这一步骤对于理解目标智能体及其环境如何相互作用至关重要,它能够提取出有助于后续预测的关键特征。 -
解码阶段 ( s f , p ) = f d e c ( z ) (s_f, p) = f_{dec}(z) (sf,p)=fdec(z):
解码器 f d e c f_{dec} fdec 对嵌入 z z z 进行解析,生成多个可能的未来轨迹 s f s_f sf 及其对应的概率 p p p。由于未来的不确定性,一个有效的运动预测模型应当能够输出多条潜在轨迹及其发生概率,从而提供对未来可能行为的全面估计。
这种两阶段的方法不仅允许模型处理复杂的时空动态,还使得它可以适应不同的输入模态(例如矢量化地图、轨迹序列),并通过融合来自不同来源的信息提高预测准确性。此外,通过使用自监督学习(SSL)进行预训练,如SmartPretrain所展示的那样,可以进一步增强模型的学习能力,使其能够在未标注的数据上学习到更丰富、更具泛化性的表示。这种方法尤其对于解决小数据量问题和提升模型在多样化的实际应用场景中的表现具有重要意义。
SSL(自监督学习)通过对典型的运动预测学习方法进行转变,引入了一个自监督的预训练阶段。其目标是通过预文本任务对编码器
f
e
n
c
f_{enc}
fenc 进行预训练,以学习更具可迁移性和通用性的嵌入表示 z。随后,编码器
f
e
n
c
f_{enc}
fenc 和解码器
f
d
n
c
f_{dnc}
fdnc 会在实际的运动预测任务上进行联合微调。在相关研究中,挑战和关注点主要在于设计能够良好扩展并有效提升下游任务性能的预文本任务。
3.2 SMARTPRETRAIN SSL 框架
我们提出了 SmartPretrain,一种模型无关且数据集无关的预训练框架,用于运动预测。该框架可以灵活地应用于各种模型,无论其架构如何,并能够利用不同格式的数据集。如图 2 所示,SmartPretrain 由两部分组成:1)一种数据集无关的情景采样策略,用于构建具有代表性和多样性的预训练数据;2)一种模型无关的自监督学习(SSL)策略,包含两个以轨迹为核心的预文本任务——轨迹对比学习(TCL,Trajectory Contrastive Learning)和轨迹重建学习(TRL,Trajectory Reconstruction Learning),这些任务共同塑造了预训练过程,从而提升下游任务的性能。
3.2.1 数据集无关的采样
在数据集无关的采样过程中,我们的目标是为对比学习生成正样本对和负样本对,并为重建学习构建轨迹,同时确保跨数据集的数据一致性。
对比/重建自监督学习的数据采样
从构建对比学习中的正样本和负样本开始,一种直观的设计方法是对同一场景中不同代理(agents)的轨迹进行对比。然而,这种方法可能会导致次优性能,原因包括:1)过于强调代理之间的空间上下文学习,而缺乏充分的时间建模;2)正样本的多样性有限,因为它们包含相同的特征。为此,我们提出了一种时间采样策略,用于生成既能捕捉空间上下文又能反映代理时间演变的样本对。
具体而言,我们将多个数据集混合起来,形成一个综合数据池,并随机采样一个具有时间范围 T = T h + T f T = T_h + T_f T=Th+Tf 的场景。然后,我们在时间上采样两个子场景,这两个子场景具有相同的时间范围 T h T_h Th,但分别从不同的时间点 t t t 和 t ′ t' t′ 开始。为了防止子场景采样过程中信息泄露,我们确保单个场景的两个子场景不会重叠,因为重叠可能会影响预文本任务的训练。随后,这两个子场景将被用于构建来自同一代理在不同时段的正样本轨迹对,以及来自不同代理或不同时段的负样本轨迹对。
需要注意的是,除了用于对比学习外,从时间点 t t t开始的子场景还将用于重建学习,因为其时间范围被设计为与目标下游数据集的输入时间范围 T h T_h Th对齐。
维护数据集无关性
为了利用具有不同配置的多个数据集并实现数据扩展,我们引入了三项关键设计:
标准化表示
由于不同数据集的格式各异,高精地图(HD map)的分辨率和代理轨迹的时间范围往往存在显著差异,导致数据不一致。为了解决这一问题,我们将地图上下文和轨迹表示标准化为统一格式。具体而言,我们通过线性插值或降采样来固定高精地图的分辨率,以确保连续点之间的一致性。此外,为了处理场景时间范围的差异,我们对时间范围较短的场景应用零填充(zero-padding),使其长度与目标时间范围 T T T 对齐。在预训练过程中,填充的轨迹步长被标记为无效,以确保它们不会影响学习过程。
确保数据质量
不同数据集在数据质量和场景复杂性方面也存在差异,通常由于感知限制或代理进入/离开场景而导致轨迹不完整。为了提高一致性并确保更高的数据质量,我们在预训练期间仅包含完整的轨迹,并从训练流程中过滤掉不完整的轨迹。
最大化数据量和多样性
不同数据集中场景数量差异显著。我们发现,直接混合所有可用的训练数据,而不是平衡每个数据集中的轨迹数量,能够取得最佳的下游任务结果。
3.2.2 模型无关的对比学习与重建学习自监督策略
我们提出了一种模型无关的自监督学习(SSL)策略,该策略由两个预文本任务组成:1)轨迹对比学习任务(TCL),通过对代理和时间窗口之间的轨迹嵌入进行对比,丰富了学习到的轨迹表示;2)轨迹重建学习任务(TRL),更紧密地贴合运动预测的主要目标,更好地引导预训练的方向。需要注意的是,这两个预文本任务通过其轨迹聚焦设计确保了模型无关性:它们仅对比和重建代理的轨迹嵌入,而不是其他嵌入(如地图嵌入或定制化嵌入)。因此,它们消除了对模型架构和地图表示的限制,可以应用于更广泛的运动预测模型。
嵌入生成
对于采样的两个子场景,我们为其包含的所有轨迹生成嵌入。具体而言,我们遵循现有自监督学习文献中的自训练策略,以提升性能并避免模型崩溃。具体来说,两个子场景被输入到两个结构相同的分支中:一个在线分支(online branch)和一个动量分支(momentum branch)。在每个分支中,输入子场景首先通过运动预测模型生成所有轨迹的嵌入,然后这些嵌入被传递给一个投影器(projector)进行进一步编码调整。在预训练期间,在线分支会持续更新,而动量分支则通过指数移动平均机制(exponential moving average mechanism)偶尔更新。此外,在线分支的嵌入还会通过一个对比预测器(contrastive predictor)生成用于对比学习的最终嵌入。
轨迹对比学习(TCL)
TCL 的设计目的是通过在时空维度上对比轨迹嵌入来学习丰富的轨迹表示。具体而言,考虑到在一个小批量(mini-batch)中采样的场景包含
N
N
N 个代理,我们现在有两组嵌入
{
z
i
,
t
}
i
=
1
N
\{z_{i,t}\}_{i=1}^N
{zi,t}i=1N 和
{
z
j
,
t
′
′
}
j
=
1
N
\{z'_{j,t'}\}_{j=1}^N
{zj,t′′}j=1N,分别由在线分支和动量分支生成。我们定义对比损失函数,使得正样本彼此靠近,负样本彼此远离:
L
TCL
=
−
∑
i
=
1
N
log
exp
(
sim
(
z
i
,
t
,
z
i
,
t
′
′
)
/
τ
)
∑
j
=
1
N
exp
(
sim
(
z
i
,
t
,
z
j
,
t
′
′
)
/
τ
)
\mathcal{L}_{\text{TCL}} = -\sum_{i=1}^N \log \frac{\exp(\text{sim}(z_{i,t}, z'_{i,t'}) / \tau)}{\sum_{j=1}^N \exp(\text{sim}(z_{i,t}, z'_{j,t'}) / \tau)}
LTCL=−i=1∑Nlog∑j=1Nexp(sim(zi,t,zj,t′′)/τ)exp(sim(zi,t,zi,t′′)/τ)
其中,
sim
(
⋅
,
⋅
)
\text{sim}(\cdot, \cdot)
sim(⋅,⋅) 表示嵌入之间的相似度(例如余弦相似度),
τ
\tau
τ 是温度超参数,用于控制分布的锐利程度。
L c = − 1 N ∑ i = 1 N l o g e x p ( r ( z i , t , z i , t ′ ) / τ ) ∑ j = 1 , j ≠ i N e x p ( r ( z i , t , z j , t ) / τ ) + ∑ j = 1 N e x p ( z i , t , z j , t ′ ′ ) \mathcal{L}_c=-\frac{1}{N}\sum^{N}_{i=1}log\frac{exp(r(z_{i, t}, z^{'}_{i, t})/\tau)}{\sum^{N}_{j=1, j \neq i} exp(r(z_{i, t}, z_{j, t})/\tau) +\sum^{N}_{j=1}exp(z_{i, t}, z^{'}_{j, t^{'}})} Lc=−N1i=1∑Nlog∑j=1,j=iNexp(r(zi,t,zj,t)/τ)+∑j=1Nexp(zi,t,zj,t′′)exp(r(zi,t,zi,t′)/τ)
公式解释
在轨迹对比学习(TCL)中,公式中的符号定义如下:
- ** sim ( ⋅ , ⋅ ) \text{sim}(\cdot, \cdot) sim(⋅,⋅) **:表示余弦相似度(cosine similarity),用于衡量嵌入向量之间的相似性。
- τ \tau τ:温度超参数(temperature hyper-parameter),用于调整分布的锐利程度。
具体来说:
- 分子部分:
exp
(
sim
(
z
i
,
t
,
z
i
,
t
′
′
)
/
τ
)
\exp(\text{sim}(z_{i,t}, z'_{i,t'}) / \tau)
exp(sim(zi,t,zi,t′′)/τ) 表示正样本对
( z i , t , z i , t ′ ′ ) (z_{i,t}, z^{'}_{i,t'}) (zi,t,zi,t′′) 的相似性,即来自同一代理在不同时间线上的轨迹嵌入。 - 分母部分:包含两种负样本对:
- 内部排斥对 ( z i , t , z j , t ) (z_{i,t}, z_{j,t}) (zi,t,zj,t):来自同一子场景中其他代理的轨迹嵌入。
- 跨场景排斥对 ( z i , t , z j , t ′ ′ ) (z_{i,t}, z'_{j,t'}) (zi,t,zj,t′′):来自不同子场景的轨迹嵌入。
通过这种目标函数设计,最大化了同一代理嵌入的相似性,同时最小化了其他样本对的相似性,从而提升了模型捕捉有意义上下文关系和时间动态的能力。
轨迹重建学习(TRL)
尽管对比学习是一种判别性任务,有助于特征区分轨迹之间的运动和上下文差异,但运动预测的最终目标是一个回归任务。因此,仅通过对比学习学到的特征可能不一定与运动预测的需求对齐或对其有益。为此,我们提出了轨迹重建学习(TRL),旨在通过共享类似的训练目标,使学到的表示尽可能贴近运动预测的需求。
具体而言,回顾第 3.2.1 节,我们采样了一个时间范围为
T
T
T 的场景,并生成了时间为
t
t
t、时间范围为
T
h
T_h
Th 的子场景的轨迹嵌入
{
z
i
,
t
}
i
=
1
N
\{z_{i,t}\}_{i=1}^N
{zi,t}i=1N。然后,我们将这些嵌入传递给一个轨迹解码器,以重建从时间点
t
′
t'
t′ 开始的子场景的轨迹片段。重建损失
L
r
L_r
Lr 被简单地设计为重建轨迹与真实轨迹之间的平均 L1 距离:
L
r
=
1
N
∑
i
=
1
N
∥
x
^
i
,
t
′
−
x
i
,
t
′
∥
1
L_r = \frac{1}{N} \sum_{i=1}^N \| \hat{x}_{i,t'} - x_{i,t'} \|_1
Lr=N1i=1∑N∥x^i,t′−xi,t′∥1
其中:
- x ^ i , t ′ \hat{x}_{i,t'} x^i,t′ 是解码器重建的轨迹。
-
x
i
,
t
′
x_{i,t'}
xi,t′ 是真实的轨迹。
有趣的是,所提出的重建损失可以被视为下游轨迹预测任务回归损失的一种广义形式。当 t = 0 t = 0 t=0 时,输入段正好是下游轨迹预测任务的观测状态 s h s_h sh。任务基于观测状态 s h s_h sh,回归从未来状态 s f s_f sf开始的 T h T_h Th 步轨迹。
3.2.3 训练细节
预训练阶段
整体预训练方案结合了轨迹对比学习(TCL)和轨迹重建学习(TRL)。组合的损失函数如下:
L
pre-train
=
λ
TCL
L
TCL
+
λ
TRL
L
TRL
L_{\text{pre-train}} = \lambda_{\text{TCL}} L_{\text{TCL}} + \lambda_{\text{TRL}} L_{\text{TRL}}
Lpre-train=λTCLLTCL+λTRLLTRL
其中:
-
L
TCL
L_{\text{TCL}}
LTCL 是轨迹对比学习的损失。
-
L
TRL
L_{\text{TRL}}
LTRL 是轨迹重建学习的损失。
-
λ
TCL
\lambda_{\text{TCL}}
λTCL 和
λ
TRL
\lambda_{\text{TRL}}
λTRL是两个任务的权重超参数,用于平衡两者的贡献。
通过联合优化这两个任务,模型能够在预训练阶段同时学习到判别性和回归性特征,从而为下游任务提供更强的表示能力。
L
=
L
c
+
λ
L
r
L = L_c + \lambda L_r
L=Lc+λLr
任务平衡参数
λ
=
1.0
\lambda=1.0
λ=1.0
在微调阶段,完成特定模型的预训练后,我们将编码器 f e n c f_{enc} fenc 初始化为预训练得到的权重,并在整个下游轨迹预测任务上对整个模型进行微调。这个过程使用模型原有的预测目标和训练计划,以确保预训练期间学习到的通用表示能够适应具体的任务需求。
实验
4.1 实验设置
数据集
我们在三个大规模运动预测数据集上训练并评估我们的方法:Argoverse(Chang等,2019)、Argoverse 2(Wilson等,2023)和Waymo Open Motion Dataset (WOMD)(Sun等,2020)。每个数据集的具体情况如下:
- Argoverse:包含从互动且密集的交通中收集的33.3万个场景。每个场景提供高清地图和2秒的历史轨迹数据,用于预测接下来3秒内的轨迹,采样频率为10Hz。其训练集、验证集和测试集分别包含20.5万、3.9万和7.8万个场景。
- Argoverse 2:将历史和预测时间延长至5秒和6秒,同样以10Hz采样。数据被划分为20万、2.5万和2.5万个场景,分别用于训练、验证和测试。
- WOMD:提供1秒的历史轨迹数据,并预测未来8秒的轨迹,采样频率为10Hz。它包括48.7万个训练场景、4.4万个验证场景和4.4万个测试场景。
值得注意的是,预训练流程仅使用了这些数据集的训练部分。
基线模型
如前所述,我们的预训练流程可以无缝集成到大多数现有的轨迹预测方法中。在实验中,我们选择了四种流行且先进的方法作为预测骨干,以评估SmartPretrain如何进一步提升性能:
- HiVT(Zhou等,2022)
- HPNet(Tang等,2024)
- Forecast-MAE(Cheng等,2023)
- QCNet(Zhou等,2023)
我们使用它们官方开源的代码进行实现。
评估指标
根据官方数据集设置(Chang等,2019),我们使用标准的运动预测指标来评估模型性能,包括最小平均位移误差(minADE)、最小最终位移误差(minFDE)和错过率(MR)。尽管预测模型可以为每个智能体生成多达6条轨迹,但这些指标仅评估具有最小终点误差的轨迹,以此作为多模态预测最佳可能表现的一个信号。
实现细节
- 我们实现了2层带有批归一化的MLP作为投影器和对比预测器,而轨迹解码器则是一个带有层归一化的2层MLP,以更好地适应其序列特性。
- 使用AdamW优化在线分支。在动量分支中,运动预测编码器和投影器的权重初始化与在线分支相同,并通过指数移动平均(EMA)策略更新。我们使用的动量值为0.996,并通过余弦调度将其增加到1.0。
- 我们使用8个Nvidia A100 40GB GPU进行单数据集预训练,以及使用32个GPU进行数据规模预训练,每个阶段都训练128个epoch。对于微调,我们使用8个GPU,并遵循各模型原始的训练计划。
这种设置确保了SmartPretrain能够有效地利用大规模未标注数据中的信息,同时在特定任务上进行微调时也能达到优异的性能。通过这种方式,我们不仅提升了模型的泛化能力,还显著增强了其在实际应用中的表现。
4.2 定量结果
应用SmartPretrain到多个模型的性能
如表1所示,我们首先报告了在Argoverse和Argoverse 2的验证集和测试集上,将SmartPretrain应用于多个最先进的预测模型时的性能。我们考虑了两种预训练设置:仅在单一下游数据集上进行预训练,以及在所有三个数据集上进行预训练。具体来说,HiVT和QCNet使用这两种设置进行了预训练,而由于计算限制,HPNet和Forecast-MAE仅在单一下游数据集上进行了预训练。结果显示,SmartPretrain能够持续改进所有考虑的模型,在下游数据集、数据分割和主要指标上的表现都有所提升。例如,在验证集上,SmartPretrain可以显著降低QCNet的minFDE、minADE和MR,分别减少了4.9%、3.3%和7.6%。此外,与仅在一个数据集上进行预训练相比,使用所有数据集进行预训练也显示出了更为一致的改进。
与其他预训练方法的性能比较
我们也比较了SmartPretrain与其他预训练方法。为了公平且有意义的比较,我们寻找有开源代码的方法。据我们所知,在本文提交时,只有Forecast-MAE是开源的。具体来说,Forecast-MAE提出了一种运动预测主干网络和一种预训练策略。我们将SmartPretrain应用到Forecast-MAE的主干网络上,并在Argoverse 2数据集上将其与Forecast-MAE的预训练策略进行比较。如表2所示,我们的预训练方法显示出比Forecast-MAE的预训练方法更大的改进(例如,在minFDE上的改进为4.5%对比1.9%)。
这些结果表明,我们的预训练流程:1) 可以灵活地应用于广泛的运动预测模型;2) 通过预训练和数据扩展持续提高性能;3) 相较于现有的预训练方法提供了更强的性能增强。
4.3 消融研究
我们进行了全面的消融研究,以分析不同组件的影响,包括预训练的数据规模、提出的两种预训练任务及其相关的超参数和配置。为了高效评估,我们使用HiVT作为预测模型,并在Argoverse上进行预训练和微调,报告在Argoverse验证集上的性能。
预训练数据集的消融
SmartPretrain被设计成与数据集无关,因此可以利用多个数据集进行预训练。这里我们引入了四种预训练设置:
- 无预训练:模型直接在下游任务上从头开始训练,不进行预训练;
- 基线预训练:模型在单一数据集(即下游数据集)上进行预训练;
- 迁移预训练:模型在不同于下游数据集的单一数据集上进行预训练;
- 数据规模预训练:模型在多个数据集上进行预训练。
然后,我们将这些预训练设置应用于HiVT和QCNet,并在其原始下游数据集上进行微调/评估。通过这种方式,我们可以详细分析每种预训练设置对最终模型性能的影响。这种系统性的分析有助于理解如何最有效地利用预训练来提升模型的泛化能力和最终预测性能。
消融研究结果与分析
预训练数据集的消融结果
如表3所示,实验结果表明:
- 所有预训练策略均有效提升预测精度:相比于从头开始训练,预训练策略显著提高了模型性能,证明了利用预训练学习可迁移特征的有效性。
- 迁移预训练效果最弱:由于预训练数据集和下游任务数据集之间的分布不匹配,迁移预训练带来的性能提升最小。这种分布差异导致学到的表示与下游任务的相关性较低。
- 多数据集预训练效果最佳:通过利用多个数据集进行预训练(即数据规模预训练),模型能够从更多样化的训练数据中学习到更具泛化性和鲁棒性的特征,从而为下游任务带来最大的性能提升。
为了进一步验证SmartPretrain在利用额外预训练数据集方面的有效性,我们还对如何在预训练阶段使用这些额外数据集进行了消融实验。我们探索了两种设置:
- 在额外数据集上使用标准运动预测任务进行预训练,然后在下游目标数据集上进行微调。
- 直接将下游目标数据集和额外数据集结合,使用标准运动预测任务进行联合训练。
由于篇幅限制,相关结果和讨论详见附录A.1。
预训练任务的消融研究
回顾我们的SSL预训练框架,它包含两个预训练任务:轨迹对比学习(TCL)和轨迹重建学习(TRL)。表4展示了针对这两种任务的消融研究结果,我们将它们应用于Argoverse数据集上的HiVT模型。观察到以下现象:
- 单一任务的效果:当单独应用时,每个预训练任务都能提升模型性能。具体来说,TCL使minFDE降低了1.1%,而TRL使minFDE降低了1.8%。
- 组合任务的效果更优:当同时使用TCL和TRL时,模型性能得到了最大幅度的提升,minFDE降低了3%。这表明两种任务具有互补性,结合使用可以更全面地捕捉轨迹的时空演化和交互关系。
总结
上述消融研究表明:
- 多数据集预训练的重要性:通过整合多个数据集,SmartPretrain能够显著增强模型的泛化能力和鲁棒性,尤其是在面对多样化的驾驶场景时表现尤为突出。
- 预训练任务的设计关键性:轨迹对比学习(TCL)和轨迹重建学习(TRL)分别从不同角度提升了模型的表示能力,而两者的结合则进一步放大了这种效果,充分体现了多任务学习的优势。
这些分析不仅验证了SmartPretrain设计的有效性,也为未来的研究提供了有价值的参考方向,例如如何进一步优化数据集整合策略或设计更高效的预训练任务。
消融研究:预训练周期和批次大小
图3展示了在HiVT模型和Argoverse数据集上的关于预训练周期和批次大小的消融研究。观察到以下几点:
- 更大的预训练周期有助于模型进行更长时间的训练,使其能够学习到更加细致和复杂的特征表示,从而最终提升性能。
- 更大的批次大小对于对比学习特别有益,因为它可以在每个训练批次中提供更多的负样本。然而,随着周期和批次大小的进一步增加,收益递减的现象开始显现,表明预训练任务与微调任务之间存在一定的差距。这意味着超过某一阈值后,额外的预训练不会显著提高模型性能,并且应该平衡效率以避免不必要的开销。
重建策略
我们的预训练流程基于输入子场景t来重建其他子场景t’的轨迹。为了进行消融研究,我们探索了更多的重建目标,分为两组:
- 第一组包括预测历史信息,如输入子场景t的轨迹以及整个场景的轨迹。
- 第二组专注于预测性重建,排除任何历史数据,包括输入子场景t的互补轨迹步骤和其他子场景t’的轨迹。
表5显示,使用重建任务进行预训练可以增强预测性能,特别是预测性目标带来了最显著的改进。
可视化结果
我们在附录A.2中展示了一些微调后的可视化结果,这些结果被归类为四种不同的场景:轨迹对齐、长轨迹预测、新行为生成和平滑安全的轨迹合成。这些结果显示了模型在各种场景下的多模态轨迹预测能力,并突出了所提出方法增强的泛化能力和鲁棒性。此外,我们在附录A.3中展示了一些预训练中的重建轨迹,以说明我们的预训练任务学习的有效性。
结论
本文介绍了SmartPretrain,一个新颖的、与模型无关且与数据集无关的自监督学习(SSL)框架,旨在增强自动驾驶中的运动预测。通过结合对比和重建的SSL技术,SmartPretrain在多个数据集上持续提升了最新模型的性能。广泛的实验表明,将SmartPretrain应用于诸如HiVT、HPNet、Forecast-MAE和QCNet等多种预测模型时,能够带来显著的性能提升。此外,我们框架的灵活性允许在多个数据集上进行有效的预训练,利用数据多样性来提高准确性和泛化能力。SmartPretrain的表现优于现有方法,证实了我们方法的有效性。这些结果突显了SmartPretrain的可扩展性、多功能性及其在驾驶环境中推进运动预测的潜力。