《BERT源码分析PART III》
Posted cx2016
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了《BERT源码分析PART III》相关的知识,希望对你有一定的参考价值。
BERT源码分析PART III
写在前面
为了方便查阅,我将完整的BERT源码分析整理成了PDF版本,可以在微信公众号NewBeeNLP后台直接下载。
继续之前没有介绍完的Pre-training部分,在上一篇中我们已经完成了对输入数据的处理,接下来看看BERT是怎么完成Masked LM和Next Sentence Prediction两个任务的训练的。
任务#1:Masked LM
get_masked_lm_output
函数用于计算任务#1的训练loss。输入为BertModel的最后一层sequence_output输出([batch_size, seq_length, hidden_size]),因为对一个序列的MASK标记的预测属于标注问题,需要整个sequence的输出状态。
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
label_ids, label_weights):
"""Get loss and log probs for the masked LM."""
# 获取mask词的encode
input_tensor = gather_indexes(input_tensor, positions)
with tf.variable_scope("cls/predictions"):
# 在输出之前添加一个非线性变换,只在预训练阶段起作用
with tf.variable_scope("transform"):
input_tensor = tf.layers.dense(
input_tensor,
units=bert_config.hidden_size,
activation=modeling.get_activation(bert_config.hidden_act),
kernel_initializer=modeling.create_initializer(
bert_config.initializer_range))
input_tensor = modeling.layer_norm(input_tensor)
# output_weights是和传入的word embedding一样的
# 这里再添加一个bias
output_bias = tf.get_variable(
"output_bias",
shape=[bert_config.vocab_size],
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
# label_ids表示mask掉的Token的id
label_ids = tf.reshape(label_ids, [-1])
label_weights = tf.reshape(label_weights, [-1])
one_hot_labels = tf.one_hot(
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
# 但是由于实际MASK的可能不到20,比如只MASK18,那么label_ids有2个0(padding)
# 而label_weights=[1, 1, ...., 0, 0],说明后面两个label_id是padding的,计算loss要去掉。
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
numerator = tf.reduce_sum(label_weights * per_example_loss)
denominator