keras中model.compile的参数'weighted_metrics'和model.fit_generator的参数'class_weight'之间的区别?
Posted
技术标签:
【中文标题】keras中model.compile的参数\'weighted_metrics\'和model.fit_generator的参数\'class_weight\'之间的区别?【英文标题】:Difference between model.compile's parameter 'weighted_metrics' and model.fit_generator's parameter 'class_weight' in keras?keras中model.compile的参数'weighted_metrics'和model.fit_generator的参数'class_weight'之间的区别? 【发布时间】:2019-11-14 00:41:35 【问题描述】:在训练用于图像分类的 keras 模型(来自 DOG BREED IDENTIFICATION 数据集 KAGGLE 的 120 个类)时,我需要使用我在某处阅读的类权重来平衡类,在示例中我看到人们使用 fit_generator 的参数 class_weight。但我在 model.compile 中发现了另一个参数 weighted_metrics,其在文档中的描述是:“在训练和测试期间由 sample_weight 或 class_weight 评估和加权的指标列表”。我要使用这个吗?请举例说明这个参数的用途。
#Calculating Class weights
counter = Counter(train_generator.classes)
max_value = float(max(counter.values()))
CLASS_WEIGHTS = classid: max_value / num_occurences
for classid, num_occurences in counter.items()
# Model Compile
model.compile(optimizer=Adam(lr=LR),
loss=categorical_crossentropy,
metrics=[categorical_accuracy],
weighted_metrics=None) # <--------------- This parameter
STEPS_PER_EPOCH = train_generator.n//train_generator.batch_size
VAL_STEPS = val_generator.n//val_generator.batch_size
model.fit_generator(train_generator,
steps_per_epoch=STEPS_PER_EPOCH,
epochs=EPOCHS,
callbacks=callback_list,
verbose=1,
class_weight=CLASS_WEIGHTS,
validation_data=val_generator,
validation_steps=VAL_STEPS) # USED CLASS_WEIGHTS HERE
【问题讨论】:
【参考方案1】:是的,您可以将它们用于您的不平衡数据集。
加权指标
是考虑到的指标列表
类权重
你传入 fit_generator。
所以在你的例子中,你可以设置
weighted_metrics=['accuracy']
和
class_weight = 0 : 3, 1: 4
weighted_metrics 参数的目的是给出一个指标列表,该列表将考虑您在 fit_generator 中传递的 class_weights。
【讨论】:
在这种情况下您建议使用哪些指标,我添加了 categorical_crossentropy 和 categorical_accuracy。 class_weights 是否会自动应用于计算损失函数,还是我必须在 model.compile 的加权指标列表中手动添加?感谢您及时回复问题。 @HimanshuTanwani 您可以使用此方法计算权重datascience.stackexchange.com/a/18722 Class_weights 应自动应用于标准损失函数集,例如,如果您使用 loss="categorical_crossentropy"。这些指标仅供您查看,它们不会影响训练,只有损失会影响。 还要注意,权重用于损失函数优化,以及度量评估(仅在训练数据上),请参阅:keras.io/api/models/model_training_apis/#fit-method"class_weight: 可选字典映射类索引(整数) 到权重(浮点)值,用于加权损失函数(仅在训练期间)。"以上是关于keras中model.compile的参数'weighted_metrics'和model.fit_generator的参数'class_weight'之间的区别?的主要内容,如果未能解决你的问题,请参考以下文章
model.compile() 是不是初始化 Keras(tensorflow 后端)中的所有权重和偏差?
keras model.compile(loss='目标函数 ', optimizer='adam', metrics=['accuracy'])(代码