如何在脚本中加载 tflite 模型?
Posted
技术标签:
【中文标题】如何在脚本中加载 tflite 模型?【英文标题】:How to load a tflite model in script? 【发布时间】:2018-10-30 18:56:44 【问题描述】:我已使用 bazel 将 .pb
文件转换为 tflite
文件。现在我想在我的 python 脚本中加载这个tflite
模型只是为了测试天气这是否给了我正确的输出?
【问题讨论】:
【参考方案1】:在 Python 中使用 TensorFlow lite 模型:
TensorFlow Lite 的冗长功能强大,因为它允许您进行更多控制,但在许多情况下您只想传递输入并获得输出,因此我制作了一个包装此逻辑的类:
以下适用于 tfhub.dev 中的分类模型,例如:https://tfhub.dev/tensorflow/lite-model/mobilenet_v2_1.0_224/1/metadata/1
# Usage
model = TensorflowLiteClassificationModel("path/to/model.tflite")
(label, probability) = model.run_from_filepath("path/to/image.jpeg")
import tensorflow as tf
import numpy as np
from PIL import Image
class TensorflowLiteClassificationModel:
def __init__(self, model_path, labels, image_size=224):
self.interpreter = tf.lite.Interpreter(model_path=model_path)
self.interpreter.allocate_tensors()
self._input_details = self.interpreter.get_input_details()
self._output_details = self.interpreter.get_output_details()
self.labels = labels
self.image_size=image_size
def run_from_filepath(self, image_path):
input_data_type = self._input_details[0]["dtype"]
image = np.array(Image.open(image_path).resize((self.image_size, self.image_size)), dtype=input_data_type)
if input_data_type == np.float32:
image = image / 255.
if image.shape == (1, 224, 224):
image = np.stack(image*3, axis=0)
return self.run(image)
def run(self, image):
"""
args:
image: a (1, image_size, image_size, 3) np.array
Returns list of [Label, Probability], of type List<str, float>
"""
self.interpreter.set_tensor(self._input_details[0]["index"], image)
self.interpreter.invoke()
tflite_interpreter_output = self.interpreter.get_tensor(self._output_details[0]["index"])
probabilities = np.array(tflite_interpreter_output[0])
# create list of ["label", probability], ordered descending probability
label_to_probabilities = []
for i, probability in enumerate(probabilities):
label_to_probabilities.append([self.labels[i], float(probability)])
return sorted(label_to_probabilities, key=lambda element: element[1])
注意
但是,您需要对其进行修改以支持不同的用例,因为我将图像作为输入传递,并获得分类([标签,概率])输出。如果您需要文本输入 (NLP) 或其他输出(对象检测输出边界框、标签和概率)、分类(仅标签)等。
此外,如果您希望输入不同大小的图像,那么您必须更改输入大小并重新分配模型 (self.interpreter.allocate_tensors()
)。这很慢(效率低下)。最好使用平台大小调整功能(例如 android 图形库)而不是使用 TensorFlow lite 模型来进行大小调整。或者,您可以使用单独的模型来调整模型的大小,这样allocate_tensors()
会更快。
【讨论】:
【参考方案2】:您可以使用 TensorFlow Lite Python 解释器在 python shell 中加载 tflite 模型,并使用您的输入数据对其进行测试。
代码会是这样的:
import numpy as np
import tensorflow as tf
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
以上代码来自TensorFlow Lite官方指南,更多详细信息,请阅读this。
【讨论】:
使用了哪个 tensorflow 版本?口译员现在不在场。 正如我刚刚使用 tensorflow 1.14.0 测试的那样,tflite Interpreter 已从 tf.contrib.lite.Interpreter 移至 tf.lite.Interpreter,请参阅上面的更新答案。 这真的很棒。我修改了文件以实际测试图像,我发现我的 .tflite 文件一定是无效的。如果你熟悉对象检测,可以看看***.com/questions/59736600/…吗? 如何在测试数据上测试而不是随机数据 我们如何对所有数据集进行预测?像“.predict(x_test)”?以上是关于如何在脚本中加载 tflite 模型?的主要内容,如果未能解决你的问题,请参考以下文章