prophet Diagnostics诊断

2021-01-14 14:40:47 浏览数 (2)

例子代码

https://github.com/lilihongjava/prophet_demo/tree/master/diagnostics

代码语言:javascript复制
# encoding: utf-8
import pandas as pd
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation
from matplotlib import pyplot as plt
from fbprophet.diagnostics import performance_metrics
from fbprophet.plot import plot_cross_validation_metric


def main():
    df = pd.read_csv('./data/example_wp_log_peyton_manning.csv')
    m = Prophet()
    m.fit(df)
    future = m.make_future_dataframe(periods=366)
    df_cv = cross_validation(
        m, '365 days', initial='1825 days', period='365 days')
    cutoff = df_cv['cutoff'].unique()[0]
    df_cv = df_cv[df_cv['cutoff'].values == cutoff]

    fig = plt.figure(facecolor='w', figsize=(10, 6))
    ax = fig.add_subplot(111)
    ax.plot(m.history['ds'].values, m.history['y'], 'k.')
    ax.plot(df_cv['ds'].values, df_cv['yhat'], ls='-', c='#0072B2')
    ax.fill_between(df_cv['ds'].values, df_cv['yhat_lower'],
                    df_cv['yhat_upper'], color='#0072B2',
                    alpha=0.2)
    ax.axvline(x=pd.to_datetime(cutoff), c='gray', lw=4, alpha=0.5)
    ax.set_ylabel('y')
    ax.set_xlabel('ds')
    ax.text(x=pd.to_datetime('2010-01-01'), y=12, s='Initial', color='black',
            fontsize=16, fontweight='bold', alpha=0.8)
    ax.text(x=pd.to_datetime('2012-08-01'), y=12, s='Cutoff', color='black',
            fontsize=16, fontweight='bold', alpha=0.8)
    ax.axvline(x=pd.to_datetime(cutoff)   pd.Timedelta('365 days'), c='gray', lw=4,
               alpha=0.5, ls='--')
    ax.text(x=pd.to_datetime('2013-01-01'), y=6, s='Horizon', color='black',
            fontsize=16, fontweight='bold', alpha=0.8)
    fig.show()

    df_cv = cross_validation(m, initial='730 days', period='180 days', horizon='365 days')
    print(df_cv.head())

    df_p = performance_metrics(df_cv)
    print(df_p.head())

    fig = plot_cross_validation_metric(df_cv, metric='mape')
    fig.show()


if __name__ == "__main__":
    main()

Prophet包括时间序列交叉验证功能,使用历史数据测量预测误差。这是通过在历史数据中选择截止(cutoff)点来完成的,并且对于每个截止点,只使用该截止点之前的数据来拟合模型。然后我们可以将预测值与实际值进行比较。下图使用Peyton Manning数据集模拟历史数据预测,其中该模型拟合5年初始(initial)历史数据,并且在一年的时间范围内进行了预测。

prophet论文进一步描述了模拟的历史预测。

使用cross_validation函数可以针对一系列历史数据截止点自动完成此交叉验证过程。我们指定预测范围(horizon),然后指定可选的初始训练周期(initial)的大小和截止点日期之间的间隔(period)。默认情况下,初始训练周期(initial)设置为预测范围(horizon)的三倍,并且每半个预测范围一个截止点。

输出cross_validation是一个dataframe,其中包含每个模拟预测日期(ds)和每个截止日期(cutoff)的真实值y,预测值yhat。特别是,对cutoff和cutoff horizon之间的每个观察点进行预测。然后,这个dataframe可以用于计算yhat和y的误差度量。

在这里,我们进行交叉验证,以评估365天的预测表现,从训练数据第730天开始为第一个截止点,然后每180天进行一次预测。在这8年的时间序列中,这相当于11个总预测(训练数据是2007/12/10 - 2016/01/20,因为最后一个截止点也要预测365天,所有最后一个cutoff在2015-01-20,第一个cutoff为2010-02-15,2015-01-20减去2010-02-15=1800天,1800/180 1=11)。

代码语言:javascript复制
from fbprophet.diagnostics import cross_validation
df_cv = cross_validation(m, initial='730 days', period='180 days', horizon = '365 days')
df_cv.head()
代码语言:javascript复制
          ds      yhat  yhat_lower  yhat_upper         y     cutoff
0 2010-02-16  8.951414    8.427466    9.450795  8.242493 2010-02-15
1 2010-02-17  8.717693    8.224716    9.212075  8.008033 2010-02-15
2 2010-02-18  8.601236    8.052325    9.124939  8.045268 2010-02-15
3 2010-02-19  8.522942    8.031072    9.017550  7.928766 2010-02-15
4 2010-02-20  8.264680    7.798614    8.733420  7.745003 2010-02-15

在R语言中,参数units必须是as.difftime类型,即周或比这个时间更短的。在Python中,initial,period和horizon应当采用Pandas Timedelta格式的字符串,接受天或比这个时间更短的单位。

performance_metrics可以通过预测度量(yhat,yhat_lower,yhat_upper对比y)计算一些有用统计,作为距截止点距离(预测到未来有多远)的函数。计算的统计量为均方误差(MSE),均方根误差(RMSE),平均绝对误差(MAE),平均绝对误差(MAPE)以及yhat_lower和yhat_upper估计的覆盖范围。这些是在df_cv按预测范围horizon(ds减cutoff)排序后的预测滚动窗口上计算的。默认情况下,每个窗口中都会包含10%的预测,但可以使用rolling_window参数进行更改。

代码语言:javascript复制
from fbprophet.diagnostics import performance_metrics
df_p = performance_metrics(df_cv)
df_p.head()
代码语言:javascript复制
  horizon       mse      rmse       mae      mape  coverage
0 37 days  0.497400  0.705266  0.507702  0.058841  0.676565
1 38 days  0.503286  0.709427  0.512702  0.059420  0.675423
2 39 days  0.525588  0.724975  0.518825  0.060023  0.672682
3 40 days  0.532851  0.729967  0.521728  0.060334  0.673824
4 41 days  0.540234  0.735006  0.522736  0.060415  0.681361

交叉验证度量指标可以通过使用plot_cross_validation_metric显示,这里显示的是MAPE。下图的点表示df_cv为每个预测的绝对百分比误差。蓝线显示MAPE,其中平均值取自点的滚动窗口。通过下图可以看到,对于未来一个月的预测,误差约为5%(0.05),对于一年的预测,误差增加到11%(0.11)左右。

代码语言:javascript复制
# Python
from fbprophet.plot import plot_cross_validation_metric
fig = plot_cross_validation_metric(df_cv, metric='mape')

可以使用可选参数rolling_window更改图中滚动窗口的大小,该参数指定在每个滚动窗口中使用的预测比例。默认值为0.1,对应df_cv于每个窗口中包含的10%的行; 增加这将导致图中平均曲线更平滑。

initial期限应该足够长,以便捕获所有模型的组成部分,特别是seasonalities和额外的回归量:对于每年季节性至少为一年,对于每周季节性至少一周等。

0 人点赞