Social LSTM 实现代码分析
Posted sinoyou
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Social LSTM 实现代码分析相关的知识,希望对你有一定的参考价值。
Social LSTM最早提出于文献 “Social LSTM: Human Trajectory Prediction in Crowded Spaces”,但经过资料查阅目前暂未找到原文献作者所提供的程序代码和数据,而在github上有许多针对该文献的实现版本代码。
本文接下来的实现代码来自https://github.com/xuerenlv/social-lstm-tf,代码语言为Python3,代码大体实现了原论文中核心原创部分的模型,包括Vanilla LSTM(没有考虑行人轨迹之间关联性的LSTM)和Social LSTM(使用池化层考虑了行人轨迹之间关联性的LSTM模型)的模型构建、训练和小样本测试的代码,但对横向对比的其他模型、模型量化评估方法等暂未实现。
本文下面将从代码中矩阵数据和列表(list)数据的维度细说实现过程和模型的特点。
Vanilla LSTM 模型
训练数据
主要功能代码文件:util.py
数据格式:
input_data, target_data = dataLoader.next_batch()
# input_data : [batch_size, seq_length, 2]
# target_data : [batch_size, seq_length, 2]
批量处理数据大小 x 序列长度大小 x 二维地址数据(已经过标准化处理,介于\\(0 - 1\\))
数据解释:
- 模型在实际使用时,对于每个输入的位置数据(源于已知数据/上一步预测数据)
LSTM Cell
将该运行后得到的输出就可用于下一时刻位置的预测,因此从dataLoader
获得的input_data
和target_data
从数据维度上只在seq_length维度上有1个大小的错位,对于行人已知的\\(t_0 - t_obs\\)的轨迹,训练时参与损失函数计算的是网络预测的\\(t_1 - t_obs+1\\)轨迹。 - 同时,其在由于训练采用Minibatch,因此输入和目标数据的有大小为
batch_size
的第一维度。
模型中间变量
LSTM序列网络是模型的核心部分,输入数据需要修改结构以满足数据要求,同时序列网络的输出结果也需要经过处理才能够使用,为此,模型主要有以下中间变量:
inputs, embedding_inputs
inputs
是input_data
的拆分版,将其拆解为序列模型每步运行时的输入数据。
embedding_inputs
是将inputs
使用embedding层后得到的输入数据,默认满足embedding_size = rnn_size = 128
,因此数据可直接用于lstm的输入数据了。
# inputs : [N_0, N_1, N_2, ....], N_i = [batch_size, 2]
# embedding_inputs = [M_0, M_1, M_2, ....], M_i = [batch_size, embedding_size]
# embedding
embedding_w = tf.get_variables("embedding_w", [2, embedding_size])
embedding_b = tf.get_variables("embedding_b", [embedding_size])
for input in embedding_inputs:
x = tf.nn.relu(tf.nn.xw_plus_b(input, embedding_w, embedding_b))
embedding_inputs.append(x)
seq2seq.rnn_decoder
由于该源码相比tensorflow的版本更迭还是有一定的年代感,其在运行LSTM模型时使用了不常用的方法:
outputs, last_state = tf.contrib.legacy_seq2seq.rnn_decoder(embedded_inputs, self.initial_state, cell, loop_function=None, scope="rnnlm")
此LSTM模型严格来说并不是seq2seq模型,其只是借用了seq2seq中decoder相同的操作步骤用在这里(手动实现也不复杂),具体来说,就是在for
循环迭代embedded_inputs
列表中的元素,使LSTM的cell
运行对应的次数,而后将序列模型的每步运行输出生成outputs
列表,并返回最后一步运行的finial_state
。
output_w, output_b
LSTM模型输出的原始outputs
数据需经线性变换为合适结构才被进一步使用,在此是对于每个大小为rnn_size
的输出向量,线性变为为大小为\\(5\\)的结果向量,有关使用目的请参见下一节。
output_size = 5 # 具体赋值目的请参见下文与原文献
# output : [batch_size * seq_length, rnn_size]
output = tf.reshape(tf.concat(outputs, 1), [-1, rnn_size])
output_w = tf.get_variable("output_w", [rnn_size, output_size])
output_b = tf.get_variable("output_b", [output_size])
# output : [batch_size * seq_length, 5]
output = tf.nn.xw_plus_b(output, output_w, output_b)
*output数据中最终含有\\(batch\\_size * seq\\_length\\)个预测的位置(每个位置由5个参数表述),相同的reshape策略可确保output中预测位置与target中实际位置的排列顺序是相同的。
模型输出
将序列模型每步输出结果合并、线性变换和变形后得到output
,传入的target_data
经过变形后得到flat_target_data
:
# model.py
# output : [batch_size * seq_length, 5]
# flat_target_data : [batch_size * seq_length, 2]
output
和flag_target_data
就是最终用于(训练时)计算损失/(采样时,不依赖于target)计算下一时刻位置的数据。
两个变量的第一维度大小均为batch_size * seq_length
(在reshape策略相同情况下,第二维度数据在数据批次和时间点上一一对应),而两个变量在第二维度数据量的差异是:原文献中假设了LSTM Cell
输出的rnn_size
大小(默认为128)的结果满足二维高斯分布(bivariate Gaussian distribution),因此使用线性变换矩阵后得到的恰是刻画二维高斯分布的5个参数$\\mu_x, \\mu_y, \\sigma_x, \\sigma_y, \\rho $(有关如何基于二维高斯分布求出预测点和损失值请原文献的引用)。
Social LSTM模型
此部分暂时未完全整理出来,根据初步的代码阅读,Social LSTM与Vanilla LSTM整体的代码框架和模型构建方法是相似的,具体有下述几方面的差异:
batch_size
和max pedestrian number
,批量训练数据的差异:在Vanilla LSTM训练时,采用了Mini Batch的数据方式使每次模型迭代时具备一定的数量规模;而Social LSTM中由于池化层的加入使得同一时刻需要有MPN
个LSTM序列迭代,而纵使存在多个LSTM序列,其实共享的是同一个Cell,因此同一场景的多位行人的轨迹(在代码中称作frame)其实就可以等价于一个batch,从而使训练Cell时有一定的数据规模。# input_data format in vanilla lstm input_data = tf.placeholder(tf.float32, [None, seq_length, 2]) ---- # input_data format in social lstm input_data = tf.placeholder(tf.float32, [seq_length, maxNumPeds, 3])
social tensor
池化层:social LSTM结构从本质上就是vanilla lstm添加了池化层,在源代码的
grid.py
包含主要的social tensor
的支持方法。social tensor
在原文中用\\(H_i^t\\)表示,每个行人\\(i\\)在不同时间点\\(t\\)中都有不同的social tensor
。对于每个张量中的值,实际是由上一时刻其他行人的
Hidden State
加和得到,Hidden State
只有LSTM Cell真正跑起来才能得到,因此最终的social tensor
是在模型运算中所得到的(这也是为什么运算量较大的原因以至原文献中又提出了一种能够在运算前得到张量的O-LSTM模型),不过在模型运行前,Hidden State的加和方式就可以通过输入数据推算得出,grid.py
做得主要就是这部分工作,其生成了数据为01真值的Grid Mask
矩阵,在模型迭代时作为参数传入,从而简化生成social tensor
的过程。
以上是关于Social LSTM 实现代码分析的主要内容,如果未能解决你的问题,请参考以下文章
深度学习100例—— 利用pytorch长短期记忆网络LSTM实现股票预测分析 | 第5例
Python对商店数据进行lstm和xgboost销售量时间序列建模预测分析|附代码数据