LSTM-based architecture for EEG signal classification based on "Channel LSTM"
Posted: 2021-04-25 07:57:17
【Question】: I have a multi-class classification problem, and I am using Keras and TensorFlow with Python 3.6. I have a working classifier based on the "stacked LSTM layers (a)" architecture described in this paper: Deep Learning Human Mind for Automated Visual Classification.
Something like this:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

model = Sequential()
model.add(LSTM(256, input_shape=(32, 15360), return_sequences=True))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(64, return_sequences=False))
model.add(Dense(6, activation='softmax'))
where 32 is the number of EEG channels and 15360 is the length of a 96-second recording sampled at 160 Hz.
I want to implement the "Channel LSTM and Common LSTM (b)" strategy mentioned in the same paper, but I don't know how to build my model with this new strategy.
Please help me. Thanks.
【Comments】:
Hi! What have you tried so far? Could you add your code and a detailed description of your problem? — Following @Mr. For Example's answer: please wait until I get correct and reliable results, so that I can share the code and the details with you.
【Answer 1】: First, there is a problem with your encoder implementation using Common LSTM: the LSTM layer of Keras takes input of shape (batch, timesteps, channel) by default, so if you set input_shape=(32, 15360), the model will read it as timesteps=32 and channel=15360, which is the opposite of what you intend.
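As a quick sanity check (a minimal sketch of my own, not part of the original answer), you can build a one-layer model with each shape and compare parameter counts: an LSTM's weight count grows with the number of input channels, so the misread shape is immediately visible.

import tensorflow as tf
from tensorflow.keras import layers, models

# Misread: Keras treats 32 as timesteps and 15360 as channels here.
wrong = models.Sequential([layers.LSTM(8, input_shape=(32, 15360))])
# Intended: 15360 timesteps of 32-channel samples.
right = models.Sequential([layers.LSTM(8, input_shape=(15360, 32))])

print(wrong.count_params())  # 4 * 8 * (15360 + 8 + 1) = 491808, dominated by "channels"
print(right.count_params())  # 4 * 8 * (32 + 8 + 1) = 1312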
This matters because the first encoding layer using Common LSTM is described as:
"At each time step t, the first layer takes the input s(·, t) (in this sense, 'common' means that all EEG channels are initially fed into the same LSTM layer)."
Therefore, the correct implementation of the encoder using Common LSTM is:
import tensorflow as tf
from tensorflow.keras import layers, models
timesteps = 15360
channels_num = 32
model = models.Sequential()
model.add(layers.LSTM(256, input_shape=(timesteps, channels_num), return_sequences=True))
model.add(layers.LSTM(128, return_sequences=True))
model.add(layers.LSTM(64, return_sequences=False))
model.add(layers.Dense(6, activation='softmax'))
model.summary()
which outputs the following (note: if you call summary() on your original implementation, you will see that its Total params is much larger):
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm (LSTM) (None, 15360, 256) 295936
_________________________________________________________________
lstm_1 (LSTM) (None, 15360, 128) 197120
_________________________________________________________________
lstm_2 (LSTM) (None, 64) 49408
_________________________________________________________________
dense (Dense) (None, 6) 390
=================================================================
Total params: 542,854
Trainable params: 542,854
Non-trainable params: 0
_________________________________________________________________
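To confirm the model trains end to end, a minimal smoke test on random data can be run (this sketch is my own; the optimizer, loss, batch size, and label encoding are assumptions, not from the original answer). Note that 15360 timesteps is a very long sequence, so even a single training step is slow on CPU.

import numpy as np

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Dummy batch: 4 recordings of 15360 timesteps x 32 channels, 6 classes.
x_dummy = np.random.rand(4, timesteps, channels_num).astype('float32')
y_dummy = np.random.randint(0, 6, size=(4,))

model.fit(x_dummy, y_dummy, epochs=1, batch_size=2)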
Second, the encoder using Channel LSTM and Common LSTM is described as:
"The first encoding layer consists of several LSTMs, each connected to only one input channel: for example, the first LSTM processes input data s(1,·), the second LSTM processes s(2,·), and so on. In this way, the output of each 'channel LSTM' is a summary of a single channel's data. The second encoding layer then performs inter-channel analysis, by receiving as input the concatenated output vectors of all channel LSTMs. As above, the output of the deepest LSTM at the last time step is used as the encoder's output vector."
Since each LSTM in the first layer processes only one channel, we need as many LSTMs in the first layer as there are channels. The following code shows how to build an encoder using Channel LSTM and Common LSTM:
import tensorflow as tf
from tensorflow.keras import layers, models

timesteps = 15360
channels_num = 32

first_layer_inputs = []
second_layer_inputs = []

# One input and one "channel LSTM" per EEG channel.
for i in range(channels_num):
    l_input = layers.Input(shape=(timesteps, 1))
    first_layer_inputs.append(l_input)
    l_output = layers.LSTM(1, return_sequences=True)(l_input)
    second_layer_inputs.append(l_output)

# Inter-channel analysis: concatenate the per-channel summaries
# and feed them through the common LSTMs.
x = layers.Concatenate()(second_layer_inputs)
x = layers.LSTM(128, return_sequences=True)(x)
x = layers.LSTM(64, return_sequences=False)(x)
outputs = layers.Dense(6, activation='softmax')(x)

model = models.Model(inputs=first_layer_inputs, outputs=outputs)
model.summary()
Output:
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_3 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_4 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_5 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_6 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_7 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_8 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_9 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_10 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_11 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_12 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_13 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_14 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_15 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_16 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_17 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_18 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_19 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_20 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_21 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_22 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_23 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_24 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_25 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_26 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_27 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_28 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_29 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_30 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_31 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
input_32 (InputLayer) [(None, 15360, 1)] 0
__________________________________________________________________________________________________
lstm (LSTM) (None, 15360, 1) 12 input_1[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM) (None, 15360, 1) 12 input_2[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM) (None, 15360, 1) 12 input_3[0][0]
__________________________________________________________________________________________________
lstm_3 (LSTM) (None, 15360, 1) 12 input_4[0][0]
__________________________________________________________________________________________________
lstm_4 (LSTM) (None, 15360, 1) 12 input_5[0][0]
__________________________________________________________________________________________________
lstm_5 (LSTM) (None, 15360, 1) 12 input_6[0][0]
__________________________________________________________________________________________________
lstm_6 (LSTM) (None, 15360, 1) 12 input_7[0][0]
__________________________________________________________________________________________________
lstm_7 (LSTM) (None, 15360, 1) 12 input_8[0][0]
__________________________________________________________________________________________________
lstm_8 (LSTM) (None, 15360, 1) 12 input_9[0][0]
__________________________________________________________________________________________________
lstm_9 (LSTM) (None, 15360, 1) 12 input_10[0][0]
__________________________________________________________________________________________________
lstm_10 (LSTM) (None, 15360, 1) 12 input_11[0][0]
__________________________________________________________________________________________________
lstm_11 (LSTM) (None, 15360, 1) 12 input_12[0][0]
__________________________________________________________________________________________________
lstm_12 (LSTM) (None, 15360, 1) 12 input_13[0][0]
__________________________________________________________________________________________________
lstm_13 (LSTM) (None, 15360, 1) 12 input_14[0][0]
__________________________________________________________________________________________________
lstm_14 (LSTM) (None, 15360, 1) 12 input_15[0][0]
__________________________________________________________________________________________________
lstm_15 (LSTM) (None, 15360, 1) 12 input_16[0][0]
__________________________________________________________________________________________________
lstm_16 (LSTM) (None, 15360, 1) 12 input_17[0][0]
__________________________________________________________________________________________________
lstm_17 (LSTM) (None, 15360, 1) 12 input_18[0][0]
__________________________________________________________________________________________________
lstm_18 (LSTM) (None, 15360, 1) 12 input_19[0][0]
__________________________________________________________________________________________________
lstm_19 (LSTM) (None, 15360, 1) 12 input_20[0][0]
__________________________________________________________________________________________________
lstm_20 (LSTM) (None, 15360, 1) 12 input_21[0][0]
__________________________________________________________________________________________________
lstm_21 (LSTM) (None, 15360, 1) 12 input_22[0][0]
__________________________________________________________________________________________________
lstm_22 (LSTM) (None, 15360, 1) 12 input_23[0][0]
__________________________________________________________________________________________________
lstm_23 (LSTM) (None, 15360, 1) 12 input_24[0][0]
__________________________________________________________________________________________________
lstm_24 (LSTM) (None, 15360, 1) 12 input_25[0][0]
__________________________________________________________________________________________________
lstm_25 (LSTM) (None, 15360, 1) 12 input_26[0][0]
__________________________________________________________________________________________________
lstm_26 (LSTM) (None, 15360, 1) 12 input_27[0][0]
__________________________________________________________________________________________________
lstm_27 (LSTM) (None, 15360, 1) 12 input_28[0][0]
__________________________________________________________________________________________________
lstm_28 (LSTM) (None, 15360, 1) 12 input_29[0][0]
__________________________________________________________________________________________________
lstm_29 (LSTM) (None, 15360, 1) 12 input_30[0][0]
__________________________________________________________________________________________________
lstm_30 (LSTM) (None, 15360, 1) 12 input_31[0][0]
__________________________________________________________________________________________________
lstm_31 (LSTM) (None, 15360, 1) 12 input_32[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 15360, 32) 0 lstm[0][0]
lstm_1[0][0]
lstm_2[0][0]
lstm_3[0][0]
lstm_4[0][0]
lstm_5[0][0]
lstm_6[0][0]
lstm_7[0][0]
lstm_8[0][0]
lstm_9[0][0]
lstm_10[0][0]
lstm_11[0][0]
lstm_12[0][0]
lstm_13[0][0]
lstm_14[0][0]
lstm_15[0][0]
lstm_16[0][0]
lstm_17[0][0]
lstm_18[0][0]
lstm_19[0][0]
lstm_20[0][0]
lstm_21[0][0]
lstm_22[0][0]
lstm_23[0][0]
lstm_24[0][0]
lstm_25[0][0]
lstm_26[0][0]
lstm_27[0][0]
lstm_28[0][0]
lstm_29[0][0]
lstm_30[0][0]
lstm_31[0][0]
__________________________________________________________________________________________________
lstm_32 (LSTM) (None, 15360, 128) 82432 concatenate[0][0]
__________________________________________________________________________________________________
lstm_33 (LSTM) (None, 64) 49408 lstm_32[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 6) 390 lstm_33[0][0]
==================================================================================================
Total params: 132,614
Trainable params: 132,614
Non-trainable params: 0
__________________________________________________________________________________________________
Now, because the model expects its input in shape (channel, batch, timesteps, 1), i.e. a list of 32 arrays of shape (batch, timesteps, 1), we have to reorder the axes of the dataset before feeding it to the model. The following example code shows how to reorder the axes from (batch, timesteps, channel) to (channel, batch, timesteps, 1):
import numpy as np

batch_size = 64
timesteps = 15360
channels_num = 32

x = np.random.rand(batch_size, timesteps, channels_num)
print(x.shape)

# Move the channel axis to the front and add a trailing feature axis.
x = np.moveaxis(x, -1, 0)[..., np.newaxis]
print(x.shape)

# Split into a list of per-channel arrays, one per model input.
x = [i for i in x]
print(x[0].shape)
Output:
(64, 15360, 32)
(32, 64, 15360, 1)
(64, 15360, 1)
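Putting the two pieces together, the reordered list can be fed straight into the multi-input model (a hedged sketch of my own; the dummy labels and training settings are assumptions, not from the original answer):

# Dummy labels for the 6-class problem.
y = np.random.randint(0, 6, size=(batch_size,))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# x is the list of 32 arrays of shape (batch, timesteps, 1) built above.
model.fit(x, y, epochs=1, batch_size=8)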
【Discussion】:
Thanks for your reply. I implemented everything you said, as shown, but I get the following error: ValueError: Layer model_2 expects 32 inputs, but it received 1 input tensor. Inputs received: [...
That is what x = [i for i in x] is for; I will update the answer accordingly. Check whether you have any other problems.
OK, now it works perfectly without any issues. Thank you @Mr.ForExample