机器学习笔记 soft-DTW(论文笔记 A differentiable loss function for time-series)
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习笔记 soft-DTW(论文笔记 A differentiable loss function for time-series)相关的知识,希望对你有一定的参考价值。
1 soft-DTW来由
DTW 算法通过动态规划求解了两个序列的相似度。这个过程1是离散的,不可微的。如果要将其应用作为神经网络的损失函数,这是不行的。因为神经网络通过对损失函数结果进行梯度下降的方法,更新参数,要求损失函数可微。
2 符号说明
论文“A differentiable loss function for time-series”(2017 ICML)中使用了 Soft minimum 来代替 DTW minimum
对于两个序列和,我们定义代价矩阵,其中δ是 可微代价函数(某一时刻x上的p维信息+某一时刻y上的p维信息——>一个实数值)【通常δ(·,·)可以用欧几里得距离】
3 soft-DTW原理
定义集合,为路径上的代价和组成的集合(从(0,0)到(i,j)的最小开销路径的cost)
如果是DTW,那么它的动态规划式子为
如1所说,由于min是一个离散的过程,不可微,所以这导致了DTW的离散。
于是Soft-DTW使用了连续的soft-min
当γ=0的时候,就是DTW,否则他就是一个可微的式子
那么这里也是类似的
这里这篇论文做了一个近似
也就等于 了
3.1 前向传播
定义,这是一个集合,其中的每一个元素A是一个矩阵,该矩阵表示两个时间序列x和y之间的对齐矩阵(alignment matrix)
对于一个特定的对齐矩阵,A中只有在(1,1)到(n,m)路径上的点(i,j),其=1,其他点的都是0。
以DTW中出现过的图为例,那种情况下的A矩阵,在红色箭头上的(i,j),其=1,其余点的均为0DTW 笔记: Dynamic Time Warping 动态时间规整 (&DTW的python实现)_UQI-LIUWJ的博客-CSDN博客
换句话说,中包含了所有(1,1)到(n,m)的路径(每个路径是一个矩阵,每个矩阵只有路径上的元素为1)
于是矩阵内积<A,Δ(x,y)>表示这条路径下的代价和(非这条路径上的点乘0,这条路径上的点乘1,再求和)
于是,soft-dtw的目标函数为
3.1.1 算法伪代码
如果γ=0的时候,也就退化为了DTW,这里不同的是,我们需要关注γ>0的情况
3.2 反向传播
soft-DTW的目的是为了计算时间序列x和时间序列y之间的动态扭曲距离,y是目标序列的话,我们反向传播计算的是对时间序列x的梯度,也即
通过链式法则,我们有
这里的分子和分母都是矩阵,所以线性代数笔记:标量、向量、矩阵求导_UQI-LIUWJ的博客-CSDN博客
也就是在我们的问题中,都是一个p×m维矩阵,那么整体上是一个np×nm的矩阵(记🔺相对于x的雅可比矩阵)
对于第二项
由于
同样地根据链式法则有:
定义元素
我们令
所以有:
当为欧几里得距离的时候,对于任意n×m维度的矩阵,有:
3.2.1 反向传播的优化
对于这个式子,我们进行反向传播的时候,如果使用自动求导机制,那每一个的计算,都需要重新从开始计算,计算到为止,所以每一个都需要的时间复杂度,而每次反向传播都需要计算一次E矩阵,所以每次反向传播计算E就需要的时间复杂度
于是论文中给出了一种动态规划的方法计算E,将时间复杂度降低至
我们知道,而只会在(i,j+1),(i+1,j+1),(i+1,j)这三项中出现,所以也只有这三项会影响到
那么根据链式法则,有:
而根据soft-dtw的定义:
我们有:
(3.7)对两边求的偏导,有:
对(3-8)式两边取对数,再乘以γ,于是有:
同理我们有
所以我们可以从开始,逐个计算到,总的时间复杂度式O(mn)
伪代码如下
以上是关于机器学习笔记 soft-DTW(论文笔记 A differentiable loss function for time-series)的主要内容,如果未能解决你的问题,请参考以下文章
机器学习笔记: 聚类 模糊聚类与模糊层次聚类(论文笔记 Fuzzy Agglomerative Clustering :ICAISC 2015)