TensorFlow Lite 模型 Android:找不到要标记的轴。要标记的有效轴的大小应大于 1

Posted

技术标签:

【中文标题】TensorFlow Lite 模型 Android:找不到要标记的轴。要标记的有效轴的大小应大于 1【英文标题】:TensorFlow Lite Model Android: Cannot find an axis to label. A valid axis to label should have size larger than 1 【发布时间】:2020-11-14 04:07:30 【问题描述】:

我正在尝试在 android 应用程序中使用预训练的 TensorFlow Lite 模型。

我已从here

下载了TensorFlow Lite 的图像分类示例应用程序

我已更改所有四个模型分类器文件中的以下代码

protected String getModelPath() 
   // you can download this file from
   // see build.gradle for where to obtain this file. It should be auto
   // downloaded into assets.
   //return "mobilenet_v1_1.0_224_quant.tflite";
  return "model_23072020.tflite";

我使用的 TensorFlow Lite 模型是用于图像分类的预训练模型。基本上它会扫描图像并生成 0 或 1 的输出。0 表示图像质量不佳,1 表示图像质量良好。

模型具有动态范围量化。

当我运行应用程序并打印 outputProbabilityBuffer.getFloatArray() 的值时,我得到以下结果

I/tensorflow:分类器:输出值[F@e08d0d3

我正在尝试使用以下代码记录值

tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());


Map<String, Float> labeledProbability = new HashMap<>();
labeledProbability.put("abc", 93.556f);

// Added logger for displaying value in console
LOGGER.i("value of output %s ", outputProbabilityBuffer.getFloatArray());

更新

我删除了上面的记录器,现在我在这一行遇到了异常

Map<String, Float> labeledProbability = new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
        .getMapWithFloatValue();

收到的错误是找不到要标记的轴。要标记的有效轴的大小应大于 1

下面提到了完整的堆栈跟踪

java.lang.IllegalArgumentException:找不到要标记的轴。一种 标记的有效轴的大小应大于 1。在 org.tensorflow.lite.support.label.TensorLabel.getFirstAxisWithSizeGreaterThanOne(TensorLabel.java:214) 在 org.tensorflow.lite.support.label.TensorLabel.(TensorLabel.java:105) 在 org.tensorflow.lite.examples.classification.tflite.Classifier.recognizeImage(Classifier.java:263) 在 org.tensorflow.lite.examples.classification.ClassifierTest.classificationResultsShouldNotChange(ClassifierTest.java:67) 在 java.lang.reflect.Method.invoke(Native Method) 在 org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50) 在 org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12) 在 org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47) 在 org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17) 在 androidx.test.internal.runner.junit4.statement.RunBefores.evaluate(RunBefores.java:80) 在 androidx.test.rule.ActivityTestRule$ActivityStatement.evaluate(ActivityTestRule.java:527) 在 org.junit.rules.RunRules.evaluate(RunRules.java:20) 在 org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325) 在 org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78) 在 org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57) 在 org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) 在 org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) 在 org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) 在 org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) 在 org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) 在 org.junit.runners.ParentRunner.run(ParentRunner.java:363) 在 org.junit.runners.Suite.runChild(Suite.java:128) 在 org.junit.runners.Suite.runChild(Suite.java:27) 在 org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) 在 org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) 在 org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) 在 org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) 在 org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) 在 org.junit.runners.ParentRunner.run(ParentRunner.java:363) 在 org.junit.runner.JUnitCore.run(JUnitCore.java:137) 在 org.junit.runner.JUnitCore.run(JUnitCore.java:115) 在 androidx.test.internal.runner.TestExecutor.execute(TestExecutor.java:56) 在 androidx.test.runner.AndroidJUnitRunner.onStart(AndroidJUnitRunner.java:389) 在 android.app.Instrumentation$InstrumentationThread.run(Instrumentation.java:2075)

【问题讨论】:

请不要更改您的问题。如果原始问题解决了,请标记它。对于另一个问题,请打开一个新线程。这是因为如果其他人在寻找相同的问题,如果您更改上下文,他们将无法找到它。 当然,下次会记住这一点。 这里有什么进展吗?我不确定下面的答案是否真的回答了这个问题,只是建议了调试/记录错误的方法。想知道您是否找到了解决方案? 【参考方案1】:

您尝试记录一个对象。 outputProbabilityBuffer 是一个输出对象,你需要它的值。

在 Classifier 类中,尝试在第 332 行记录 entry.getKey() 和 entry.getValue()。这应该是标签名称和置信度。

还定义了一个 getter:getTitle() 例如它在第 522 行的 CameraActivity 中使用,您可以在其中记录输出。

【讨论】:

【参考方案2】:

问题在于您使用的是二元分类器,但示例 TensorFlowLite 图像识别代码要求您使用多类分类器(即,输出多个概率的分类器)。

你使用的模型只输出一个概率(你可以通过 outputProbabilityBuffer.getFloatArray() 得到),其中 0 是一个类,1 是另一个类。

要解决此问题,请勿使用示例用于处理输出并将标签映射到执行后的概率的代码 - 只需执行以下操作:

// after execution
float result = outputProbabilityBuffer.getFloatArray()[0];
if (result < 0.5) 
  // means prediction was for category corresponding to 0
 else 
  // means prediction was for category corresponding to 1

【讨论】:

以上是关于TensorFlow Lite 模型 Android:找不到要标记的轴。要标记的有效轴的大小应大于 1的主要内容,如果未能解决你的问题,请参考以下文章

sh 从tensorflow冷冻模型到tensorflow lite

将保存的tensorflow模型转换为tensorflow Lite的正确方法是啥

TensorFlow-Lite 预训练模型在 Android 演示中不起作用

使用 toco 将假量化 tensorflow 模型(.pb)转换为 tensorflow lite 模型(.tflite)失败

移动端目标识别(3)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之Running on mobile with TensorFlow Lite (写的很乱,回

如何创建可轻松转换为 TensorFlow Lite 的模型?