Keras多实例分类问题

Posted

技术标签:

【中文标题】Keras多实例分类问题【英文标题】:Keras multi instance classification problem 【发布时间】:2019-10-10 05:20:50 【问题描述】:

我正在尝试对一维信号数据集执行二进制分类任务。这是训练批次x 和真实批次y 的输入形状。换句话说,我有 16 个信号,每个批次有 38400 个时间步长和 1 个特征。每个信号被分类为一个类的 150 次(一种多实例学习问题)。

x 形状 (16, 38400, 1)

y 形状 (16, 150, 1)

这是我目前使用的网络:

model = Sequential()
model.add(LSTM(2, return_sequences=True, input_shape=params['input_shape'], dropout=0.5))
model.add(TimeDistributed(Dense(params["num_categories"])))
model.add(Activation('softmax' if params['num_categories'] != 1 else 'sigmoid'))

from keras.optimizers import Adam
optimizer = Adam(lr=params['learning_rate'], clipnorm=params.get("clipnorm", 1))
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

return model

请注意params['input_shape'] 等于[:, :, None]。以下是运行代码时网络层的输出形状:

Tensor("lstm_1/transpose_1:0", shape=(?, ?, 2), dtype=float32) : (1, 38400, 2)
Tensor("time_distributed_1/Reshape_1:0", shape=(?, ?, 1), dtype=float32) : (1, 38400, 1)
Tensor("activation_1/Sigmoid:0", shape=(?, ?, 1), dtype=float32) : (1, 38400, 1)

问题是我们遇到了这个错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [4,150,1] vs. [4,38400,1]
     [[node metrics/acc/Equal = Equal[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](_arg_activation_1_target_0_1/_65, metrics/acc/Round)]]
     [[node metrics/auroc/auc/assert_greater_equal/Assert/AssertGuard/Assert/Switch_2/_93 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1152_...t/Switch_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

我知道这个错误与y有关,它与网络输出的形状不同。但是,我不确定应该向网络添加什么样的层才能获得所需的形状。你能帮忙吗?

提前致谢。

【问题讨论】:

【参考方案1】:

您可以在网络中添加一个全连接层。如果您的网络的输出大小为 B,并且您希望其大小为 A,则可以定义一个全连接层来获取大小为 B 的输入并产生大小为 A 的输出。请参阅 tf.keras.layers 的文档。密集:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense

【讨论】:

以上是关于Keras多实例分类问题的主要内容,如果未能解决你的问题,请参考以下文章

多标签分类 Keras 指标

keras解决多标签分类问题

keras解决多标签分类问题

Keras CNN:图像的多标签分类

如何使用 keras 实现多标签分类神经网络

Keras - Precision 和 Recall 大于 1(多分类)