机器学习笔记 soft-DTW(论文笔记 A differentiable loss function for time-series)

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习笔记 soft-DTW(论文笔记 A differentiable loss function for time-series)相关的知识,希望对你有一定的参考价值。

 1 soft-DTW来由

        DTW 算法通过动态规划求解了两个序列的相似度。这个过程1是离散的,不可微的。如果要将其应用作为神经网络的损失函数,这是不行的。因为神经网络通过对损失函数结果进行梯度下降的方法,更新参数,要求损失函数可微

2 符号说明 

论文“A differentiable loss function for time-series”(2017 ICML)中使用了 Soft minimum 来代替 DTW minimum

        对于两个序列,我们定义代价矩阵,其中δ是 可微代价函数(某一时刻x上的p维信息+某一时刻y上的p维信息——>一个实数值)【通常δ(·,·)可以用欧几里得距离】 

3 soft-DTW原理

        定义集合,为路径上的代价和组成的集合(从(0,0)到(i,j)的最小开销路径的cost)

         如果是DTW,那么它的动态规划式子为

         如1所说,由于min是一个离散的过程,不可微,所以这导致了DTW的离散。

        于是Soft-DTW使用了连续的soft-min

         当γ=0的时候,就是DTW,否则他就是一个可微的式子

(在max函数的平滑(log-sum-exp trick)_UQI-LIUWJ的博客-CSDN博客 中,我们知道

那么这里也是类似的  

                                  

                                  

这里这篇论文做了一个近似

                                 

  也就等于   了                     

 3.1 前向传播

        定义,这是一个集合,其中的每一个元素A是一个矩阵,该矩阵表示两个时间序列x和y之间的对齐矩阵(alignment matrix)

         对于一个特定的对齐矩阵,A中只有在(1,1)到(n,m)路径上的点(i,j),其=1,其他点的都是0。

          以DTW中出现过的图为例,那种情况下的A矩阵,在红色箭头上的(i,j),其=1,其余点的均为0DTW 笔记: Dynamic Time Warping 动态时间规整 (&DTW的python实现)_UQI-LIUWJ的博客-CSDN博客

         换句话说,中包含了所有(1,1)到(n,m)的路径(每个路径是一个矩阵,每个矩阵只有路径上的元素为1)

        于是矩阵内积<A,Δ(x,y)>表示这条路径下的代价和(非这条路径上的点乘0,这条路径上的点乘1,再求和)

        于是,soft-dtw的目标函数为

 3.1.1 算法伪代码

如果γ=0的时候,也就退化为了DTW,这里不同的是,我们需要关注γ>0的情况

 

3.2 反向传播

        soft-DTW的目的是为了计算时间序列x和时间序列y之间的动态扭曲距离,y是目标序列的话,我们反向传播计算的是对时间序列x的梯度,也即

         

        通过链式法则,我们有

        这里的分子和分母都是矩阵,所以线性代数笔记:标量、向量、矩阵求导_UQI-LIUWJ的博客-CSDN博客

        

 

也就是在我们的问题中,都是一个p×m维矩阵,那么整体上是一个np×nm的矩阵(记🔺相对于x的雅可比矩阵)

 

 对于第二项

由于

同样地根据链式法则有:

 定义元素

我们令 

 所以有:

为欧几里得距离的时候,对于任意n×m维度的矩阵,有:

 3.2.1 反向传播的优化

         对于这个式子,我们进行反向传播的时候,如果使用自动求导机制,那每一个的计算,都需要重新从开始计算,计算到为止,所以每一个都需要的时间复杂度,而每次反向传播都需要计算一次E矩阵,所以每次反向传播计算E就需要的时间复杂度

         于是论文中给出了一种动态规划的方法计算E,将时间复杂度降低至

         我们知道,而只会在(i,j+1),(i+1,j+1),(i+1,j)这三项中出现,所以也只有这三项会影响到

 那么根据链式法则,有:

 

 而根据soft-dtw的定义:

我们有:

 (3.7)对两边求的偏导,有:

 

 对(3-8)式两边取对数,再乘以γ,于是有:

  

同理我们有

 

 所以我们可以从开始,逐个计算到,总的时间复杂度式O(mn)

 伪代码如下

 

 

以上是关于机器学习笔记 soft-DTW(论文笔记 A differentiable loss function for time-series)的主要内容,如果未能解决你的问题,请参考以下文章

论文笔记/机器学习笔记:CBAM

论文笔记/机器学习笔记:CBAM

机器学习笔记: 聚类 模糊聚类与模糊层次聚类(论文笔记 Fuzzy Agglomerative Clustering :ICAISC 2015)

论文研读笔记——基于障碍函数的移动机器人编队控制安全强化学习

论文阅读笔记

机器学习笔记 - YOLOv7 论文简述与推理