交通预测旨在准确预测城市未来的交通流动模式,需要同时考虑时间与空间维度。但是,分布偏移现象是该领域的一个主要难题,因为现有模型在遇到与训练数据分布差异显著的测试数据时,往往难以实现有效的泛化。
为应对这一挑战,本文提出了一个简洁而通用的时空提示调整机制——FlashST,该机制旨在使预训练模型能够适应不同下游数据集的特定属性,从而提升其在多样化预测场景中的泛化性能。
具体而言,FlashST 框架通过一个轻量化的时空提示网络进行上下文信息的学习和提取,捕捉到时空中的恒定规律,并能够灵活适应各种不同的场景。此外,本研究还引入了一种分布映射技术,用以调整预训练数据和下游数据之间的分布差异,以促进在时空预测任务中的知识迁移。实验结果证实了 FlashST 在处理不同城市交通数据集时的有效性。
【论文标题】FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction
【论文地址】https://arxiv.org/abs/2405.17898
【论文源码】 https://github.com/HKUDS/FlashST
【实验室链接】https://sites.google.com/view/chaoh
论文概述
01、挑战
尽管现有的时空预测技术已经证明了它们的有效性,但大多数模型在面对不同下游数据集和任务时,由于分布变化,往往难以实现有效的泛化。在现实城市环境中,训练数据和测试数据之间的分布不一致性,成为了实现精确预测的障碍。如图1所示,如果将直接从数据集A学习到的参数应用于数据集B的测试,由于两个数据集在时空特征上的显著差异,可能会导致性能下降。因此,为了提高时空预测模型的泛化能力,需要有效地适应这种分布变化。设计具有适应性的方法面临以下挑战:
图1:FlashST动机:左图展示了不同交通数据集中数据分布的多样性,而右图显示了端到端模型的参数对训练集A过度拟合,未能泛化到测试集B
(1)有效提取时空上下文信息:从下游任务中准确提取复杂的时空上下文信息是关键所在。然而,要使预训练模型能够迅速理解并融合那些仅在测试阶段才能访问的新领域数据的空间和时间属性,是一个极具挑战性的任务。
(2)缩小训练与测试数据的分布差异:训练数据集与测试数据集之间往往存在显著的分布差异,这种情况在它们源自不同的时空背景和领域时尤为突出。设计模型适应框架,使其能够高效地缩小这种分布差异,并捕捉到时空中的恒定特征,对于提升模型的适应性至关重要。
02、贡献
(1)为了应对挑战1,本文提出了一种时空上下文信息提取方法,该方法能够捕捉到未见数据中的上下文信号,从而使得模型能够适应多样的时空环境。
(2)本文还引入了一种统一的分布映射策略,以增强 FlashST 框架。这一策略通过正则化提示嵌入,有效地缩小了预训练数据与下游任务之间的分布差异,从而促进了从预训练阶段到下游时空预测任务的知识有效迁移。
论文方法
图2: FlashST 模型框架
01、时空上下文学习
时空上下文学习框架通过一个专门的时空提示网络来实现,该网络由两个核心组件构成:(i)时空上下文提取器:这一机制能够高效地识别并捕获新数据中的时间与空间上下文信号。这种方法允许模型从数据的特定上下文中学习,进而有效地适应多样化的时空情境。(ii)时空依赖性分析器:该组件将时间和空间之间的复杂相互作用纳入网络结构中。通过精确捕捉和模拟这些依赖性,网络能够深刻理解不同时空要素之间的相互联系和影响。
(1)时空上下文蒸馏
(2)时空依赖建模
02、统一分布映射机制
为弥合预训练阶段与多样化下游任务中未见数据的分布差异,我们在 FlashST 框架中集成了一种分布映射策略。该策略旨在将预训练数据和下游数据映射到一个共同的分布空间。通过实现数据分布的一致性,促进了知识的无缝迁移,确保了预训练阶段获得的知识能够高效地应用于下游时空场景。
为达成上述目标,FlashST 利用标准化的提示嵌入来确保在各种不同的下游数据集中维持一致的分布特性。我们借鉴了对比学习领域的多项研究成果,特别是引入了基于 infoNCE 损失函数的方法来规范提示网络生成的表示。该损失函数的作用是拉近正样本对之间的表示距离,同时增加负样本对之间的表示差异。通过采用无需额外标注数据的自监督学习方法,优化 infoNCE 损失有助于生成更均衡的嵌入分布。
03、预训练和下游任务提示范式
在预训练阶段,我们利用专门的预训练数据集来训练并优化模型的所有参数。进入提示微调阶段时,我们通过在新的、未见过的数据集上进行有限的训练周期,专门调整提示网络的参数。这样的过程使得模型能够迅速适应新数据。我们提出的 FlashST 框架具有模型无关性,这意味着它可以与多种现有的时空预测模型无缝结合,作为下游任务的集成部分。
实验结果与分析
01、总体表现
(1)对比实验
对比实验的结果如下表,数据显示,与端到端的时空模型相比,我们提出的方法在多个城市数据预测场景中显示出了明显的优越性。这些结果强有力地证实了 FlashST 在精确捕捉城市数据中的复杂时空模式方面的有效性。我们提出的上下文学习框架在将这些知识迁移到新的下游任务上表现出了卓越的能力。通过有效管理分布差异,FlashST 成功地缩小了预训练模型与实际预测场景之间的语义差异。
表1:FlashST对比实验
(2)模型无关&模型微调
- 模型无关性优势。提出模型的一个关键优势在于其与模型无关的特性,这意味着它可以轻松地与多种现有的时空编码器结合,提供高度的灵活性,并避免了对特定模型选择的依赖。下表展示了 FlashST 方法与四种领先的时空模型(包括STGCN、GWN、MTGNN、PDFormer)的无缝集成能力。评估结果彰显了 FlashST 的多功能性,以及当与优秀的时空模型结合时,其性能的显著提升。成功地与先进模型集成,进一步增强了 FlashST 的适应性,以及其在多样化城市数据场景中提高预测准确性的潜力。
- 与全参数微调的对比。为了进一步证明框架的有效性,我们将FlashST的提示微调方法与全参数微调进行了对比。"w/o Finetune"指的是在预训练后直接对目标数据集进行预测,不进行任何形式的微调。而"w/ Finetune"则表示在预训练之后,采用全参数微调来适应目标数据。值得注意的是,与端到端的预测效果相比,全参数微调的结果可能未能充分利用预训练阶段的成果。在没有有效对齐预训练模型与下游任务的情况下,可能会引入噪声,导致错误的微调方向和次优的性能表现。
表2:模型无关&模型微调实验
02、模型效率评估
(1)训练时长
在本节中,我们对三种不同训练场景的时长进行了测量:端到端训练、全参数微调以及 FlashST 的模型效率评估,结果汇总在下表。对于端到端训练和全参数微调,我们按照现有基线的配置,设定了 100 个训练周期,并设定了 25 个周期的早停标准。
FlashST 的提示调整周期被限制在 20 个以内,以展示模型对新数据集的快速适应能力。结果显示,基于相同基线的端到端训练和全参数微调在效率上是相近的。两种设置之间的训练时长差异,主要是由于不同的初始化参数影响了模型的收敛速度。FlashST 框架显著提升了计算效率,将基线模型的训练时长缩短了 20% 至 80%,这极大地增强了模型适应新时空数据的能力。
表3:不同模型计算时间统计(秒)
快速收敛性。在本节中,我们探究了 FlashST 在不同数据集上实现收敛的速度。下图展示了在采用 PEMS07(M) 和 CA-D5 数据集,并将 MTGNN 作为下游模型时,验证误差的下降趋势。
观察结果表明,整合了 FlashST 方法后,下游模型能够在少数几个调整周期内迅速收敛。与此相比,端到端训练和微调方法则需要更多的训练周期来适应新的数据环境。这种快速收敛的特性得益于我们提出的时空提示网络和数据分布映射策略。这些策略使得模型能够结合已有的预训练知识,并利用新数据的时空特性,从而迅速适应各种不同的时空场景。
图3:FlashST收敛速度
03、消融实验
(1)时空上下文蒸馏的作用
我们分别进行了去除时间上下文信息(-TC)和空间上下文信息(-SC)的实验。实验结果显示,移除了时空上下文后大多数性能指标都出现了明显的下降。这一现象强调了在上下文学习过程中,保留时间和空间上下文信息的重要性。有效地编码时间特征和整合空间特征对于捕捉时空中的恒定模式以及加深模型对数据的理解极为关键。
(2)时空依赖建模的效用
我们进一步进行了去除时间编码器(-TE)和空间编码器(-SE)的实验。结果表明,时空依赖编码在上下文学习过程中,对于整合不同时间段和地点之间的复杂关系发挥了关键作用。包含时间与空间依赖编码器的模型能够更深刻地理解并利用时间与空间的复杂相互作用。这种能力显著提升了下游模型对新时空场景的快速适应性。
(3)统一分布映射机制的影响
我们从两个维度评估了统一分布映射策略的有效性:
- -Uni,不使用统一分布映射策略。性能的降低证实了该策略对模型性能的正面贡献。FlashST通过将不同的时空数据嵌入到统一的分布空间中,有效地缓解了预训练数据与新时空数据之间的分布差异。
- r/BN,将统一分布映射策略替换为批归一化。批归一化通过根据小批量数据的局部统计特性来标准化数据,这有助于解决神经网络训练中的内部协变量偏移问题,并提升模型的收敛速度。然而,由于没有建立起预训练数据与下游任务数据之间的联系,下游模型难以有效地从预训练中迁移知识。我们提出的分布映射策略确保了模型能够充分利用在预训练阶段获得的知识。通过校准不同数据源的分布,模型能够更快地适应新的时空环境,并做出更准确的预测。
图4:FlashST消融实验
04、超参分析
在本节中,我们探讨了不同超参数设置对模型性能的影响,特别是温度系数和损失权重系数的设置。我们的研究结果揭示了当温度系数设置为,损失权重系数设置为时,模型能够实现最优的性能表现。值得注意的是,这些超参数的微调对最终性能的影响并不显著,这显示了模型对不同参数配置具有很好的鲁棒性。即便在特征尺度不一致的情况下,模型也能有效地学习到区分不同区域嵌入特征的表示。此外,模型的性能不会因为统一性损失的增加而出现大幅波动,这表明我们的分布映射策略并没有对预测损失造成干扰。这些发现进一步证实了我们策略的有效性,并有助于下游模型快速适应新的时空环境。
图5:关于和的模型超参实验
05、案例研究
为了验证我们提出的统一分布映射方法在将不同数据表示映射到统一分布上的有效性,我们对应用了分布映射和未应用分布映射的提示嵌入进行了可视化对比。我们首先使用 PCA 技术将每个嵌入样本的高维特征降至二维,然后通过 L2 范数将这些降维后的嵌入投影到单位圆上,具体效果见下图。
图6:提示嵌入的分布可视化
可视化结果显示,统一分布映射策略成功地将提示嵌入转换为接近均匀的分布形态,提供了有力的证据支持。相较之下,未使用该策略的模型未能达到这种理想的分布特性。FlashST 通过将新时空环境的数据映射到统一的分布中,加强了利用预训练知识并快速适应新数据集的能力,进而在多种交通任务上取得了更好的表现。
总结展望
本文介绍了 FlashST 框架,旨在使时空预测模型能够快速适应各种未见过数据的下游任务。该框架采用了一个包含时空上下文提炼和依赖性建模的时空提示网络。通过捕捉关键的上下文信号和模拟时间与空间的复杂相互作用,框架能够灵活地适应多样化的时空环境。为应对分布差异问题,我们引入了一个分布映射机制,它能够调整预训练数据与下游数据的分布,从而促进了在时空预测任务中的知识迁移。通过一系列广泛的实验,我们证明了 FlashST 在多个时空预测领域的有效性和其强大的泛化能力。在未来的研究中,我们将继续探索将大型语言模型整合到 FlashST 框架中,以作为知识引导,进一步提升模型的适应性和预测性能。