只用全连接,也能搭建SOTA时间序列预测模型?

Posted fareise

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了只用全连接,也能搭建SOTA时间序列预测模型?相关的知识,希望对你有一定的参考价值。

‍微信公众号“圆圆的算法笔记”,持续更新NLP、CV、搜推广干货笔记和业内前沿工作解读~

后台回复“交流”加入“圆圆的算法笔记”交流群;回复“时间序列“、”多模态“、”迁移学习“、”NLP“、”图学习“等获取各个领域干货算法笔记~

时间序列预测的主流模型结构一直以来被RNN、CNN、Transformer三大模型主体主导。而Nbeats的出现,让纯全连接的模型结构在时间序列预测问题上也能取得非常好的效果。这篇文章梳理了Nbeats系列工作,从最基础的Nbeats版本,到可以引入外部变量的Nbeats版本,再到能够处理时空预测的Nbeats版本。

推荐阅读:

12篇顶会论文,深度学习时间序列预测经典方案汇总

Spatial-Temporal时间序列预测建模方法汇总

层次时间序列预测指南

如何建模时间序列的不确定性?

如何搭建适合时间序列预测的Transformer模型?

1. 基础版本Nbeats

论文题目:N-BEATS: NEURAL BASIS EXPANSION ANALYSIS FOR INTERPRETABLE TIME SERIES FORECASTING

下载地址:https://arxiv.org/pdf/1905.10437.pdf

Nbeats是Element AI发表于ICLR 2020的一篇工作,目前引用量200多万,在时间序列预测这个领域还是比较有影响力的。Nbeats开创了一个全新的时间序列预测backbone,仅通过全连接实现时间序列预测。Nbeats的核心思路是,通过多层全连接进行时间序列分解,每层拟合时间序列部分信息(之前层拟合的残差),有点类似于GBDT的啥思路。

Nbeats的具体模型结构如下图所示。整个模型包括多个stack,每个stack包括多个block,每个block是Nbeats的最基础结构模块,由多个全连接层组成。每个block包含两个主要部分,第一个部分将输入的时间序列映射成expansion coefficients,第二部分将expansion coefficients映射回时间序列。

什么是expansion coefficients呢?expansion coefficients可以理解为存储了时间序列内在的信息形成的一个低维向量。在模型实现上,其实就是一个向量映射过程:将输入的时间序列(维度为length)映射成低维向量(维度为dim),第二部分再将其映射回时间序列(长度为length)。这个步骤也类似于AutoEncoder,将时间序列映射成一个低维向量保存核心信息,再还原回来。假设每个block模块的输入为x,将其映射为expansion coefficients的过程可以表示为 

每个模块会生成两组expansion coefficients,一组用来预测未来(forecast),另一组用来预测过去(backcast)。这个过程可以表示为如下公式:

最终,每个block对输入的序列进行处理后,输出一个预测未来的序列,以及一个预测过去的序列。每个block的输入,是上一层block地输入减去上一层block的输出。通过这种方式,模型每层需要处理的是之前层无法正确拟合的残差,也起到了一个将时间序列进行逐层分解,每层预测时间序列一部分的作用。最终的预测结果,是各个block预测结果的加和。

为了能让模型的分解具有可解释性,文中也提出了在各个层引入一些先验知识,强制让某些层学习某种类型的时间序列特性,实现可解释的时间序列分解。实现的方法是通过约束expansion coefficients到输出序列的函数形式来实现。例如想让某层block主要预测时间序列的季节性,就可以用下面的公式强制输出是季节性的:

下图是作者用这种思路约束不同层学习不同信息的可视化结果,有的层学习了趋势性,有的层学习了周期性。

2. 引入外部变量的Nbeats

论文题目:Neural basis expansion analysis with exogenous variables:Forecasting electricity prices with NBEATSx

下载地址:https://arxiv.org/pdf/1905.10437.pdf

第一版本的Nbeats,输入只能是单一的时间序列,无法输入额外的特征。而在时间序列预测问题中,诸如日期信息、节日信息、属性信息等外部特征也是非常重要的。因此基于初版Nbeats,该团队又提出了可以引入外部特征的Nbeatsx,和初版Nbeats的主要区别是引入了外部特征X。

模型的主体结构和Nbeats基本一致,每个block除了输出序列外,还会输入外部特征,二者一起通过全连接层得到隐状态,再基于隐状态生成expansion coefficients。

此外,文中还提出了另一种引入time-dependent特征的方法:采用一个TCN(时间序列卷积模块)作为encodeer,对外部特征进行编码,将该编码作为生成预测结果的因素:

3. 用于时空预测的Nbeats

论文题目:GAGA: Fully Connected Gated Graph Architecture for Spatio-Temporal Traffic Forecasting

下载地址:https://arxiv.org/pdf/2007.15531.pdf

该团队提出的第三个版本的Nbeats将Nbeats扩展到了时空预测领域,能够处理存在空间关系的多个时间序列的建模。模型总共分为Graph Edge Weight、Time Gate、Graph Gate,以及和Nbeats类似的多block、stack嵌套的主体全连接结构,模型整体结构如下图所示。

Graph Edge Weight:将每个节点都表示成一个embeddiing,对于两个节点的关系,使用这两个节点对应的embedding内积表示,最终可以用一个矩阵W表示整个空间上两两节点之间的关系:

Time Gate Block(紫色部分):Time Gate模块主要用来对时间特征进行编码。这里会将时间特征和每个节点的embedding进行拼接后,通过全连接生成时间特征相关的表示。这个时间特征信息会通过乘法、除法的方式,实现将时间信息从原始序列中剥离(对应除法)再融合(对应乘法)的目的。Time Gate的输出对应两个Linear映射结果,一个用于从历史序列中剥离时间因素,一个用于最终的还原,因为历史序列和未来序列处于不同时间窗口,因此这里采用两套参数分别建模时间信息。Time Gate这种将时间因素从序列中剥离的方法,让模型能够更专注于非时间因素的纯序列建模,降低模型拟合难度。

Graph Gate Block(绿色部分):Graph Gate的目的是让每个节点都能融合整个空间中其他节点的信息。核心思路是,对于每一个节点,利用它和其他节点的Graph Edge Weight融合其他节点的时间序列信息。Graph Gate模块输出一个矩阵G,矩阵中每个元素的计算过程如下,相当于用节点i和节点j之间的距离对节点j的时间序列的每个时刻k加权:

由于使用了节点i的最大值进行减法和除法的处理后再过Relu激活函数,对于不相关的节点对,Relu激活函数可以起到过滤的作用,让这两个节点的输出结果为0。

上面生成的矩阵G、每个节点的embedding以及每个节点的序列,最终作为后续Nbeats模型的输入。Time Gate部分对输入进行信息分离,再在最终的产出结果上进行时间信息融合。

4. 总结

本文介绍了Nbeats的序列模型,包括最基础的Nbeats模型,以及在此基础上衍生出来的引入外部特征的Nbeats(Nbeatsx)和用于时空预测的Nbeats(GAGA)。Nbeats相比其他时间序列预测模型,独创了一种全部为全连接的backbone,核心思路是通过序列信息分解、分而治之的方法,实现准确的时间序列预测。

微信公众号“圆圆的算法笔记”,持续更新NLP、CV、搜推广干货笔记和业内前沿工作解读~

后台回复“交流”加入“圆圆的算法笔记”交流群;回复“时间序列“、”多模态“、”迁移学习“、”NLP“、”图学习“等获取各个领域干货算法笔记~

后台留言”交流“,加入圆圆算法交流群~

【历史干货算法笔记】

12篇顶会论文,深度学习时间序列预测经典方案汇总

如何搭建适合时间序列预测的Transformer模型?

Spatial-Temporal时间序列预测建模方法汇总

最新NLP Prompt代表工作梳理!ACL 2022 Prompt方向论文解析

图表示学习经典工作梳理——基础篇

一网打尽:14种预训练语言模型大汇总

Vision-Language多模态建模方法脉络梳理

花式Finetune方法大汇总

从ViT到Swin,10篇顶会论文看Transformer在CV领域的发展历程

以上是关于只用全连接,也能搭建SOTA时间序列预测模型?的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch搭建全连接网络训练MNIST数据集分类任务和气温预测回归任务及全连接网络过拟合和欠拟合的调参方式

清华&旷视让全连接层“内卷”,卷出MLP性能新高度

搭建三层全连接网络

神经网络--从0开始搭建全连接网络和CNN网络

将全连接层转换为 conv2d 并预测输出?

前馈全连接神经网络和函数逼近时间序列预测手写数字识别