影响6个时序Baselines模型的代码Bug

2024-09-18 16:57:01 浏览数 (1)

前言

我是从去年年底开始入门时间序列研究,但直到最近我读FITS这篇文章的代码时,才发现从去年12月25号就有人发现了数个时间序列Baseline的代码Bug。如果你已经知道这个Bug了,那可以忽略本文~

这个错误最初在Informer(AAAI 2021 最佳论文)中被发现,是爱丁堡大学的Luke Nicholas Darlow发现。这个问题对时间序列预测领域的一系列广泛研究都有影响,这个Bug影响了包括Patch TST、DLinear、Informer、Autoformer、Fedformer、FiLM在内的经典baseline。

  • PatchTST (ICLR 2023) - Link to affe‍cted code
  • DLinear (AAAI 2022 reported version) - Link to affected code
  • Informer (AAAI 2021 Best Paper) - Link to affected code
  • Autoformer (NIPS 2021 reported version) - Link to affected code
  • Fedformer (ICML 2022) - Link to affected code
  • FiLM (ICLR 2023) - Link to affected code

FITS这篇文章发布一个修复方法,以帮助社区在他们的工作中解决这个问题。参考链接:https://anonymous.4open.science/r/FITS/README.md

错误描述

这个错误源于数据加载器中的错误实现。测试数据加载器(test dataloader)使用了drop_last=True那么模型的评估可能会基于不完整的测试数据集,从而导致对模型性能的不准确评估,甚至可能导致不同模型之间比较的不公平。这个问题在使用较大批量大小时尤为明显,因为更大的批量大小更容易导致数据集大小不能被整除的情况。

注:在PyTorch等数据加载框架中,drop_last参数通常用于控制当数据集大小不能被批量大小整除时,是否丢弃最后一个不完整的批量。在训练过程中,为了保持每个epoch迭代次数的稳定性,通常会设置drop_last=True。然而,在测试或验证过程中,为了获得对模型性能的准确评估,应该确保所有测试数据都被使用,因此应该设置drop_last=False

解决方法

在data_factory.py 中,修改代码:

代码语言:javascript复制
if flag == 'test':
    shuffle_flag = False
    drop_last = True
    batch_size = args.batch_size
    freq = args.freq

如下:

代码语言:javascript复制
if flag == 'test':
    shuffle_flag = False
    drop_last = False #True
    batch_size = args.batch_size
    freq = args.freq

在代码 script 文件夹(e.g., ./exp/exp_main.py), 做出如下修改 (约在 290行),from

代码语言:javascript复制
preds = np.array(preds)
trues = np.array(trues)
inputx = np.array(inputx) # some times there is not this line, it does not matter

to:

代码语言:javascript复制
preds = np.concatenate(preds, axis=0)
trues = np.concatenate(trues, axis=0)
inputx = np.concatenate(inputx, axis=0) # if there is not that line, ignore this

作者说可以通过在维度0(即batch大小)上拼接(concatenate)剩余的数据解决问题,而不必丢弃最后一个不完整的batch。

结果更新

已发现的错误主要影响像ETTh1和ETTh2这样的小型数据集的结果。有趣的是,对于其他数据集,如ETTm1上的PatchTST等某些模型,却表现出了增强的性能。FITS(假设是指某个时间序列预测模型)仍然保持了足够好且与其他最先进模型相媲美的性能。

从更新后的结果我们发现,最能打还是Patch TST以及FITS。关于这两篇论文,我之前做过详细的解读,感兴趣可以关注阅读。

0 人点赞