连接 Keras 中的屏蔽输入

Posted

技术标签:

【中文标题】连接 Keras 中的屏蔽输入【英文标题】:Concatenate masked inputs in Keras 【发布时间】:2021-09-18 02:19:31 【问题描述】:

使用 Keras API,我正在尝试构建一个神经网络模型,如下所示。我有两个输入,每个输入都是分类时间序列,它们都已转换为 one-hots。在一个批次中,时间序列通常具有不同的长度,所以我用零填充它们到一个共同的长度。我想首先对每个输入应用一个掩蔽层以忽略填充,然后对每个输入应用 TimeDistributed Dense 层,最后在将结果传递到 LSTM 之前连接 Dense 层的输出。 (这并不重要,在 LSTM 之后应用最终的线性 Dense 层。)像这样:

from tensorflow.keras.layers import Input,Dense,Concatenate
from tensorflow.keras.layers import TimeDistributed,LSTM,Masking
from tensorflow.keras import Model

input1=Input(shape=(None,5),batch_size=batch_size) #five categories for feature 1
input2=Input(shape=(None,3),batch_size=batch_size) # three categories for feature 2
masked1=Masking(mask_value=0,input_shape=(None, 5))(input1)
masked2=Masking(mask_value=0,input_shape=(None, 3))(input2)
dense1=TimeDistributed(Dense(16,activation='relu'))(masked1)
dense2=TimeDistributed(Dense(16,activation='relu'))(masked2)
concat=TimeDistributed(Concatenate(axis=-1))([dense1,dense2])
lstm=LSTM(512,activation='tanh',return_sequences=True,stateful=False)(concat)
out=TimeDistributed(Dense(5,activation='linear'))(lstm)
model=Model(inputs=[input1,input2],outputs=out)

但是,连接屏蔽输入在 Keras 中似乎不起作用;与 Concatenate 层的行会导致错误 AttributeError: 'list' object has no attribute 'shape'

谁能建议一种方法来完成我正在尝试做的事情或类似的事情?

【问题讨论】:

我认为你不需要用于连接层的时间分布层包装器。 哇,这实际上解决了问题,哈哈 【参考方案1】:

来自 cmets

from tensorflow.keras.layers import Input,Dense,Concatenate
from tensorflow.keras.layers import TimeDistributed,LSTM,Masking
from tensorflow.keras import Model

input1=Input(shape=(None,5),batch_size=batch_size) #five categories for feature 1
input2=Input(shape=(None,3),batch_size=batch_size) # three categories for feature 2
masked1=Masking(mask_value=0,input_shape=(None, 5))(input1)
masked2=Masking(mask_value=0,input_shape=(None, 3))(input2)
dense1=TimeDistributed(Dense(16,activation='relu'))(masked1)
dense2=TimeDistributed(Dense(16,activation='relu'))(masked2)
concat=Concatenate(axis=-1)([dense1,dense2])
lstm=LSTM(512,activation='tanh',return_sequences=True,stateful=False)(concat)
out=Dense(5,activation='linear')(lstm)
model=Model(inputs=[input1,input2],outputs=out)

输出:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(32, None, 5)]      0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(32, None, 3)]      0                                            
__________________________________________________________________________________________________
masking (Masking)               (32, None, 5)        0           input_1[0][0]                    
__________________________________________________________________________________________________
masking_1 (Masking)             (32, None, 3)        0           input_2[0][0]                    
__________________________________________________________________________________________________
time_distributed (TimeDistribut (32, None, 16)       96          masking[0][0]                    
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (32, None, 16)       64          masking_1[0][0]                  
__________________________________________________________________________________________________
concatenate (Concatenate)       (32, None, 32)       0           time_distributed[0][0]           
                                                                 time_distributed_1[0][0]         
__________________________________________________________________________________________________
lstm (LSTM)                     (32, None, 512)      1116160     concatenate[0][0]                
__________________________________________________________________________________________________
dense_2 (Dense)                 (32, None, 5)        2565        lstm[0][0]                       
==================================================================================================
Total params: 1,118,885
Trainable params: 1,118,885
Non-trainable params: 0
_____________________________________________________________________________

(转述自 Kaveh)

【讨论】:

以上是关于连接 Keras 中的屏蔽输入的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Keras 的 BLSTM 中屏蔽填充?

如何在 Keras 中屏蔽损失函数(mae)?

如何屏蔽rs232针脚17未连接引起的告警

用星号屏蔽python中的用户输入

jquery屏蔽数字输入

PLC与变频器屏蔽线怎么连接