CNN培训多标签分类---不起作用

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了CNN培训多标签分类---不起作用相关的知识,希望对你有一定的参考价值。

尝试预测纹理图像的标签,图像可以包含两个标签,如['带状','条纹'],但大多数只有一个标签。

输出精度非常高....第一个时期可以有0.96 acc ...但是预测数组都接近于0,这是错误的,必须至少有一个数字与1相关。

有人能帮我吗?谢谢!!

这是代码

Input image = (read by opencv)/255
Multi-labels = First LabelEncoder convert to numbers, then keras.to_categorical

然后我建立了一个CNN模型如下

X_train, X_test, y_train, y_test = train_test_split(img_array, test_value, test_size=0.1)

model = Sequential()
model.add(Conv2D(filters=64, kernel_size=(5, 5), padding='Same', data_format='channels_last', activation='relu',
                 input_shape=(300, 300, 3)))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Conv2D(filters=32, kernel_size=(3, 3), padding='Same', activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(300, init ='uniform',activation='relu'))
model.add(Dense(285, init = 'uniform',activation='sigmoid'))
model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])


history = model.fit(X_train, y_train, batch_size= 24, epochs=10, validation_split=0.15)
答案

如果您的模型只有2个标签,那么最后一层应该是

model.add(Dense(2, init = 'uniform',activation='sigmoid'))

但是,您的班级不平衡也会影响准确性。如果您的班级不平衡太高,您的模型将显示95%+培训,验证和测试准确性,但个人准确度仍然很低,并且该模型不适用于真实世界数据。

使用以下内容可以了解详细和基于类的准确性:

from sklearn.metrics import classification_report



X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30)
X_test1, X_valid, y_test1, y_valid = train_test_split(X_test, y_test, test_size=0.30)
model.fit(X_train, y_train, batch_size=64, epochs=8, shuffle=True, validation_data=(X_test1,y_test1), callbacks=[metrics])

Y_TEST = np.argmax(y_valid, axis=1)
y_pred = model.predict_classes(X_valid)

print("#"*50,"\n",classification_report(Y_TEST, y_pred))

请分享您的课程分布以便进一步了解。

另一答案

不确定为什么Dense层中的神经元数量是285.如果有47个类别,那么Dense层的输出神经元应该是47.此外,使用像he_normal而不是uniform的内核初始化器。 https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py

model.add(Dense(47, activation='sigmoid'))
model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])

这是一个包含5个类的多标签分类示例。

https://github.com/suraj-deshmukh/Keras-Multi-Label-Image-Classification

以上是关于CNN培训多标签分类---不起作用的主要内容,如果未能解决你的问题,请参考以下文章

Keras CNN:图像的多标签分类

用于多标签图像分类的 CNN

多标签文本分类HFT-CNN: Learning Hierarchical Category Structure for Multi-label Short Text Categorization

多标签文本分类Deep Learning for Extreme Multi-label Text Classification

为多标签分类生成 sklearn 指标的问题

使用 PyTorch 的多标签、多类图像分类器 (ConvNet)