如何在具有使用@tf.keras.utils.register_keras_serializable 注册的自定义函数的 Tensorflow Serving 中提供模型?
Posted
技术标签:
【中文标题】如何在具有使用@tf.keras.utils.register_keras_serializable 注册的自定义函数的 Tensorflow Serving 中提供模型?【英文标题】:How to serve model in Tensorflow Serving that has custom function registered with @tf.keras.utils.register_keras_serializable? 【发布时间】:2021-08-27 02:21:12 【问题描述】:我正在使用keras 预处理层:
@tf.keras.utils.register_keras_serializable(package='Custom', name=None)
def transform_domain(inputs):
""" Lowercase domain and remove possible leading `www.`"""
inputs = tf.strings.lower(inputs)
return tf.strings.regex_replace(inputs, '^www\.', '', replace_global=False)
pre_layer_domain = TextVectorization(
standardize=transform_domain, ...)
# The above layer is added to my model somewhere
# ...
model.fit(...)
model.save('out_path')
到目前为止一切都很好。问题是当我尝试加载模型时,如果我的自定义函数不存在,则会出现异常:
# In a new Python interpreter
model = tf.keras.models.load_model('out_path')
>>> RuntimeError:
Unable to restore a layer of class TextVectorization. Layers of class
TextVectorization require that the class be provided to the model loading
code, either by registering the class using @keras.utils.register_keras_serializable
on the class def and including that file in your program, or by passing the
class in a keras.utils.CustomObjectScope that wraps this load call.
所以该消息暗示了两件事:
在训练模型时在函数上使用@keras.util.register_keras_serializable +have that function loaded in the runtime while loading the model
在加载模型时使用keras.utils.CustomObjectScope
上下文管理器
太棒了,这两个选项都很可爱而且很棒——前提是我可以控制模型的加载方式。但是在Tensorflow Serving
中加载模型时我该怎么做呢?
我尝试通过几种不同的方式将具有该功能的 Python 模块添加到 out_path/assets
目录中。无法弄清楚如何导出该函数,因此当我加载模型时它会以某种方式自动加载。
【问题讨论】:
【参考方案1】:经过一些实验,我实际上可以使用任一方法导出模型:
model.save('out_path/dummy_model/1')
# or
tf.saved_model.save(model, 'out_path/dummy_model/1')
当我将该模型加载到 Tensorflow Serving 中时,“它就可以工作。”我的自定义函数已作为图表的一部分加载,无需进一步工作:
docker run --rm -p 8501:8501 \
--mount type=bind,source=$(pwd)/out_path/,target=/app \
-e MODEL_BASE_PATH=/app \
-e MODEL_NAME=dummy_model -t tensorflow/serving
上面示例的适当端点上的推理 (http://localhost:8501/v1/models/dummy_model:predict) 工作并返回与我在 Python 环境中查询模型时获得的相同值,即使在 Python 领域,我需要使用上面发布的原始错误消息提供的两个建议之一来加载模型(在我的情况下这不是问题,因为当我在 Python 中加载模型时,我可以完全控制环境和加载模型的代码——通常是笔记本或其他用于检查模型内部的脚本):
with tf.keras.utils.CustomObjectScope('transform_domain': transform_domain):
new_model = tf.keras.models.load_model('dummy_model/1')
【讨论】:
以上是关于如何在具有使用@tf.keras.utils.register_keras_serializable 注册的自定义函数的 Tensorflow Serving 中提供模型?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Java 中使用 SAX Parser 检查 xml 标签是不是具有属性?
如何在 Swift 中使用 Alamofire 与 Multipart 一起使用具有不同键和多种参数的多个图像