Keras 中的多分类预测不止一个?

Posted

技术标签:

【中文标题】Keras 中的多分类预测不止一个?【英文标题】:More than one prediction in multi-classification in Keras? 【发布时间】:2018-01-12 12:56:57 【问题描述】:

我正在学习如何使用 Keras 设计卷积神经网络。我开发了一个使用 VGG16 作为基础的简单模型。我在数据集中有大约 6 类图像。这是我的模型的代码和描述。

model = models.Sequential()
conv_base = VGG16(weights='imagenet' ,include_top=False, input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
conv_base.trainable = False
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(6, activation='sigmoid'))

这是编译和拟合模型的代码:

model.compile(loss='categorical_crossentropy',
        optimizer=optimizers.RMSprop(lr=1e-4),
         metrics=['acc'])
model.summary()

callbacks = [
    EarlyStopping(monitor='acc', patience=1, mode='auto'),
    ModelCheckpoint(monitor='val_loss', save_best_only=True, filepath=model_file_path)
]

history = model.fit_generator(
    train_generator,
    steps_per_epoch=10,
    epochs=EPOCHS,
    validation_data=validation_generator,
    callbacks = callbacks,
    validation_steps=10)

这是预测新图像的代码

img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
plt.figure(index)
imgplot = plt.imshow(img)

x = image.img_to_array(img)
x = x.reshape((1,) + x.shape)
prediction = model.predict(x)[0]
# print(prediction)

通常 model.predict() 方法预测不止一个类。

[0 1 1 0 0 0]

我有几个问题

    多类分类模型预测多个输出是否正常? 如果预测不止一个类,如何在训练期间测量准确度? 如何修改神经网络以便只预测一个类别?

感谢任何帮助。非常感谢!

【问题讨论】:

【参考方案1】:

你不是在做多类分类,而是多标签。这是由在输出层使用 sigmoid 激活引起的。要正确进行多类分类,请在输出处使用 softmax 激活,这将产生类的概率分布。 正如预期的那样,采用具有最大概率 (argmax) 的类将产生单个类预测。

【讨论】:

感谢您的解释。就是这样!你能解释一下 Keras 训练函数是如何测量准确性的吗?例如,在训练期间,假设模型预测 [0, 0.2, 0.4, 0.7, 0.1, 0]。 keras 看第 3 类,精度为 0.7 并认为是模型的输出与真值进行比较? 这种方法不是库特定的。一般的多类分类概率是使用具有n个输出类的softmax激活,将“pick”作为最高概率之一。因此,在您的情况下,是的,第 3 类被认为是选定的类。单个样本的准确度是二进制的,并在您的输入上取平均值。 @TMS。

以上是关于Keras 中的多分类预测不止一个?的主要内容,如果未能解决你的问题,请参考以下文章

Keras 中的多步多元时间序列分类

Keras中具有二进制分类的多标签

Keras 中具有类权重的多标签分类

使用 Keras 进行多类图像分类的多重预测

有没有一种方法可以使用多标签分类,但当模型仅预测 keras 中的一个标签时认为是正确的?

Keras - 带权重的多标签分类