如何在 CNN-LSTM 模型上应用 model.fit() 函数?
Posted
技术标签:
【中文标题】如何在 CNN-LSTM 模型上应用 model.fit() 函数?【英文标题】:How to apply model.fit() function over an CNN-LSTM model? 【发布时间】:2020-09-20 12:22:56 【问题描述】:我正在尝试使用它来将图像分为两类。我也应用了 model.fit() 函数,但它显示错误。
ValueError: 形状为 (90, 1) 的目标数组被传递为形状 (None, 10) 的输出,同时用作损失 binary_crossentropy。这种损失期望目标具有与输出相同的形状。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D, LSTM
import pickle
import numpy as np
X = np.array(pickle.load(open("X.pickle","rb")))
Y = np.array(pickle.load(open("Y.pickle","rb")))
#scaling our image data
X = X/255.0
model = Sequential()
model.add(Conv2D(64 ,(3,3), input_shape = (300,300,1)))
# model.add(MaxPooling2D(pool_size = (2,2)))
model.add(tf.keras.layers.Reshape((16, 16*512)))
model.add(LSTM(128, activation='relu', return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))
opt = tf.keras.optimizers.Adam(lr=1e-3, decay=1e-5)
model.compile(loss='binary_crossentropy', optimizer=opt,
metrics=['accuracy'])
# model.summary()
model.fit(X, Y, batch_size=32, epochs = 2, validation_split=0.1)
【问题讨论】:
训练数据的形状y
必须等于模型的输出
Y 形状是 (90,),致密层形状是 (None,10) 所以我需要申请 Y.reshape(90,10)?
【参考方案1】:
如果您的问题是明确的,那么您的问题是您使用的是 binary_crossentropy
而不是 categorical_crossentropy
;确保您确实有分类问题而不是二元分类问题。
另外,请注意,如果您的标签是简单的整数格式,例如 [1,2,3,4...] 并且不是 one-hot-encoded,您的 loss_function 应该是 sparse_categorical_crossentropy
,而不是 categorical_crossentropy
。
如果你确实有二进制分类问题,就像上面的错误中所说的,请确保:
-
损失是 binary_crossentroy +
Dense(1,activation='sigmoid')
损失是 categorical_crossentropy + Dense(2,activation='softmax')
【讨论】:
非常感谢。它工作得很好,但是你怎么知道我需要使用 categorical_crossentropy 因为我的直觉是使用这个模型进行二元分类,比如区分猫和狗的图像。还有可能从中得到混淆矩阵吗? 您需要明确问题定义:如果您有多个类,那么它是 softmax + categorical/sparse_categorical。你把 10 个神经元放在最后一级......这就是我推断你有 10 个类而不是 2 个的方式。如果你确实有 1 个类作为输出,请确保使用 binary_crossentropy + sigmoid。 我更新了答案以反映您所有可能的需求。 非常感谢。是否可以打印 CNN-LSTM 模型的混淆矩阵? 从 TF 2.2 开始,您可以使用内置的混淆矩阵:tensorflow.org/api_docs/python/tf/math/confusion_matrix。以上是关于如何在 CNN-LSTM 模型上应用 model.fit() 函数?的主要内容,如果未能解决你的问题,请参考以下文章
基于注意力机制的CNN-LSTM模型及其应用(含软硬注意力区别)
在 keras 中使用 CNN-LSTM 模型进行序列到序列分类
使用 Keras 训练 CNN-LSTM 时卡在第一个 epoch