bert-as-service输出分类结果
Posted zyl007
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了bert-as-service输出分类结果相关的知识,希望对你有一定的参考价值。
bert-as-service: Mapping a variable-length sentence to a fixed-length vector using BERT model
默认情况下bert-as-service只提供固定长度的特征向量,如果想要直接获取分类预测结果呢?
bert提供了的run_classifier.py
以训练分类模型,同时bert提供了离线评估的方法。
一些可能的部署思路
- bert基于tensorflow实现,可以参考tensorflow-serving对外提供部署服务
- 参考bert代码修改离线接口为在线推断,基于flask/django提供部署服务
- 修改bert-as-service提供高效在线预测服务
bert-as-service的强大可以参考:Serving Google BERT in Production using Tensorflow and ZeroMQ
修改bert-as-service提供分类预测
思路:https://github.com/hanxiao/bert-as-service/issues/213
bert-as-service 默认情况下,不会加载分类层
- 加载模型的同时加载分类层的权重和bias
- 添加分类层
在graph.py#L79中添加
if args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
hidden_size = 768
output_weights = tf.get_variable(
"output_weights", [args.num_labels, hidden_size],
)
output_bias = tf.get_variable(
"output_bias", [args.num_labels])
tvars = tf.trainable_variables()
注意:在加载权重和bias的时候不要定义初始化方法,否则会从初始化方法进行加载,而不是微调模型。
elif args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
# pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1)
logits = tf.matmul(pooled, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
pooled = tf.nn.softmax(logits, axis=-1)
以上是关于bert-as-service输出分类结果的主要内容,如果未能解决你的问题,请参考以下文章
Java异常捕获之一道try-catch-finally语句题
try-catch-finally 中哪个部分可以省略?try-catch-finally 中,如果 catch 中 return 了,finally 还会执行吗?