如何计算 BERT 中多类分类的所有召回准确率和 f1 度量?

Posted

技术标签:

【中文标题】如何计算 BERT 中多类分类的所有召回准确率和 f1 度量?【英文标题】:How can i calculate all recall accuracy precision and f1 measure for multi class classification in BERT? 【发布时间】:2021-06-28 05:52:13 【问题描述】:
from sklearn.metrics import f1_score

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = v: k for k, v in label_dict.items()
    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: label_dict_inverse[label]')
        print(f'Accuracy: len(y_preds[y_preds==label])/len(y_true)\n')

需要计算多类模型的分类报告,但它只给出准确率和 f1 分数

【问题讨论】:

我不完全理解您的问题。你试过sklearn's classification report吗? 是的,我也需要显示混淆矩阵 那么 sklearn 的分类报告有什么问题呢?它输出召回率、准确率和 f 分数 是文本分类问题,包含10个多类。 你真的尝试过使用sklearns的classification_report吗?输出是什么?有没有错误?您的问题不清楚,需要澄清一下。 【参考方案1】:

我想你正在使用 Pytorch 环境。这是打印数据集中每个类的 F1、召回率和精度的正确代码。如果您有经过训练的模型,请加载它以及要测试的数据集。

from sklearn.metrics import classification_report, confusion_matrix

val_dataset = LoadDataset('/content/val.csv')
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=51) # Load the data

model.load_state_dict(torch.load('vit-base.bin')) # Load the trained model
model.cuda()                                      # For putting model on GPUs
with torch.no_grad():
 image,target = next(iter(val_loader))
 image = image.to(device)
 target = target.flatten().to(device)
 prediction = model(image)

prediction = prediction.argmax(dim=1).view(target.size()).cpu().numpy()
target = target.cpu().numpy()
print(classification_report(target,prediction,target_names=val_dataset.LE.classes_)) # LE is the label encoder

【讨论】:

以上是关于如何计算 BERT 中多类分类的所有召回准确率和 f1 度量?的主要内容,如果未能解决你的问题,请参考以下文章

多分类问题的准确率,召回率怎么计算

牢记分类指标:准确率、精确率、召回率、F1 score以及ROC

从 CSV 多类数据集中计算精度和召回率。

【基础概念】准确率和召回率

准确率(Precision)、召回率(Recall)、F值(F-Measure)

精度评定中的准确率(Precision)和召回率(Recall)