如何在Keras中使用fit_generator()来加权? [关闭]

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何在Keras中使用fit_generator()来加权? [关闭]相关的知识,希望对你有一定的参考价值。

我正在尝试使用keras来拟合CNN模型来对图像进行分类。数据集具有来自某些类的更多图像,因此其不平衡。

我在Keras中读到了如何权衡损失的不同之处,例如:https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras,这是很好的解释。但是,它总是解释fit()函数,而不是fit_generator()函数。

实际上,在fit_generator()函数中我们没有'class_weights'参数,而是我们有'weighted_metrics',我不理解它的描述:“weighted_metrics:在训练期间由sample_weight或class_weight评估和加权的指标列表测试“。

如何从“class_weights”传递到“weighted_metrics”?有人会有一个简单的例子吗?

答案

我们在class_weight(Keras v.2.2.2)中有fit_generator根据docs

Class_weight:可选字典将类索引(整数)映射到权重(浮点)值,用于加权损失函数(仅限训练期间)。这可以用来告诉模型“更多地关注”来自代表性不足的类的样本。

假设您有两个类[正面和负面],您可以将class_weight传递给fit_generator

model.fit_generator(gen,class_weight=[0.7,1.3])

以上是关于如何在Keras中使用fit_generator()来加权? [关闭]的主要内容,如果未能解决你的问题,请参考以下文章

Keras:网络不使用 fit_generator() 进行训练

keras中fit_generator()的优势

使用 keras.utils.Sequence 和 keras.model.fit_generator 时出现 KeyError。

如何在 keras fit_generator() 中定义 max_queue_size、workers 和 use_multiprocessing?

keras:为 fit_generator 使用 ImageDataGenerator 和 KFold 的问题

keras/scikit-learn:使用 fit_generator() 进行交叉验证