如何使用训练有素的 BERT 模型检查点进行预测?

Posted

技术标签:

【中文标题】如何使用训练有素的 BERT 模型检查点进行预测?【英文标题】:How to use trained BERT model checkpoints for prediction? 【发布时间】:2019-11-10 01:24:09 【问题描述】:

我使用 SQUAD 2.0 训练了 BERT,并使用 BERT-master/run_squad.py 在输出目录中获得了 model.ckpt.datamodel.ckpt.metamodel.ckpt.index(F1 分数:81)以及 predictions.json

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

我尝试将model.ckpt.metamodel.ckpt.indexmodel.ckpt.data 复制到$BERT_LARGE_DIR 目录并更改run_squad.py 标志如下,以仅预测答案而不使用数据集进行训练:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
  --do_train=False \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

抛出bucket directory/model.ckpt不存在错误。

如何利用训练后生成的检查点进行预测?

【问题讨论】:

【参考方案1】:

通常,训练的检查点是在训练时在--output_dir参数指定的目录中创建的。 (在你的情况下是gs://some_bucket/squad_large/)。每个检查点都会有一个编号。你必须找出最大的数字;例如:model.ckpt-12345。现在,在您的评估/预测中设置--init_checkpoint 参数,使用输出目录和最后保存的检查点(编号最高的模型)。 (在你的情况下,它应该类似于--init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>

【讨论】:

【参考方案2】:

在第二个代码中,FLAG init_checkpoint 我认为应该是:

--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt

和上面一样,而不是--init_checkpoint=$BERT_LARGE_DIR/model.ckpt

如果问题仍然存在,您是否使用multi_cased_L-12_H-768_A-12 预训练模型?

【讨论】:

我使用的是 cased_L-24_H-1024_A-16 预训练模型。我会让你知道结果。 这并没有加载训练好的模型,而是预训练好的模型。其他答案有效。要使用经过训练的模型,我们必须指定检查点编号。

以上是关于如何使用训练有素的 BERT 模型检查点进行预测?的主要内容,如果未能解决你的问题,请参考以下文章

BERT:深度双向预训练语言模型

预训练语言模型(GPT,BERT)

中文bert wwm 预训练参考笔记

如何从检查点使用 tf.estimator.Estimator 进行预测?

使用 BERT 等预训练模型进行文档分类

《自然语言处理实战入门》深度学习 ---- 预训练模型的使用(ALBERT 进行多标签文本分类与微调 fine tune)