Tf Lite 模型图像分类打印标签
Posted
技术标签:
【中文标题】Tf Lite 模型图像分类打印标签【英文标题】:Printing label for Tf Lite model Image Classification 【发布时间】:2021-12-20 18:18:17 【问题描述】:我正在开发一个 Image Claasification TF Lite 模型,以使用此 link 检测人脸的蒙版或没有蒙版。我按照链接在顶点 AI 中训练了图像多类分类并下载了 TF lite 模型。模型的标签是“mask”和“no_mask”。为了测试模型,我写了如下代码:
interpret= tf.lite.Interpreter(model_path="<FILE_PATH>")
input= interpret.get_input_details()
output= interpret.get_output_details()
interpret.allocate_tensors()
pprint(input)
pprint(output)
data= cv2.imread("file.jpeg")
new_image= cv2.resize(data,(224,224))
interpret.resize_tensor_input(input[0]["index"],[1,224,224,3])
interpret.allocate_tensors()
interpret.set_tensor(input[0]["index"],[new_image])
interpret.invoke()
result= interpret.get_tensor(output[0]['index'])
print (" Prediction is - ".format(result))
将此代码用于我的一张图片会给我的结果是:
[[30 246]]
现在我也想在结果中打印标签。例如:
面具:30
no_mask:46
有什么方法可以实现吗?
请帮忙,因为我是 TF Lite 的新手
【问题讨论】:
你看过带有 netron.app 的模型来验证它输出 2 个标签吗?我看不到输出。 我检查了 netron 应用程序。它正在输出两个标签。数组为 [1,2]。 那么 [[30 246]] 是什么? [[30 246]] 是预测的输出。我提交了一张用于预测 mask 或 no_mask 的图像,因此其中一个值是 mask 预测,另一个是 no_mask 预测。我需要找出应该打印的即标签。例如:[[掩码:30,No_mask:246]]。注意:值 30 和 246 是中间结果。置信度值或概率将通过将这些值除以 255 来计算。因此实际概率将为:[[ 0.11, 0.96]] 我想你自己已经回答了这个问题。使用 python 获取值。标签是您在开始时设置的任何内容。检查原始模型,看看什么是 30,什么是 246。 【参考方案1】:我自己解决了。从 Vertex AI 下载的 .tflite 模型包含名为“dict.txt”的标签文件,其中包含所有标签。检查 GCP 文档 here。要获取此标签文件,我们首先需要解压缩 .tflite 文件,该文件将为我们提供 dict.txt。如需更多信息,请查看the tflite documentation 和how to read associate file from the models。
之后我从the github link label.py引用了以下代码:
import argparse
import time
import numpy as np
from PIL import Image
import tensorflow as tf
interpret= tf.lite.Interpreter(model_path="<FILE_PATH>")
input= interpret.get_input_details()
output= interpret.get_output_details()
interpret.allocate_tensors()
pprint(input)
pprint(output)
data= cv2.imread("file.jpeg")
new_image= cv2.resize(data,(224,224))
interpret.resize_tensor_input(input[0]["index"],[1,224,224,3])
interpret.allocate_tensors()
interpret.set_tensor(input[0]["index"],[new_image])
interpret.invoke()
floating_model= input[0]['dtype'] == np.float32
op_data= interpret.get_tensor(output[0]['index'])
result= np.squeeze(op_data)
top_k=result.agrsort()[-5:][::1]
labels=load_labels("dict.txt")
for i in top_k:
if floating_model:
print(':08.6f: '.format(float(result[i]), labels[i]))
else:
print(':08.6f: '.format(float(result[i] / 255.0), labels[i]))
【讨论】:
以上是关于Tf Lite 模型图像分类打印标签的主要内容,如果未能解决你的问题,请参考以下文章