TensorFlow-Lite 预训练模型在 Android 演示中不起作用
Posted
技术标签:
【中文标题】TensorFlow-Lite 预训练模型在 Android 演示中不起作用【英文标题】:Tensorflow-Lite pretrained model does not work in Android demo 【发布时间】:2018-05-25 11:35:29 【问题描述】:Tensorflow-Lite android 演示使用它提供的原始模型:mobilenet_quant_v1_224.tflite。见:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
他们还在这里提供了其他预训练的精简模型:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md
但是,我从上面的链接下载了一些较小的模型,例如 mobilenet_v1_0.25_224.tflite,并在演示应用程序中将原始模型替换为此模型,只需更改 ImageClassifier.java
中的 MODEL_PATH = "mobilenet_v1_0.25_224.tflite";
.应用程序崩溃:
12-11 12:52:34.222 17713-17729/? E/AndroidRuntime:致命异常: 相机背景 进程:android.example.com.tflitecamerademo,PID:17713 java.lang.IllegalArgumentException:无法获取输入尺寸。 第 0 个输入应该有 602112 字节,但找到了 150528 字节。 在 org.tensorflow.lite.NativeInterpreterWrapper.getInputDims(Native 方法) 在 org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:82) 在 org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:112) 在 org.tensorflow.lite.Interpreter.run(Interpreter.java:93) 在 com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108) 在 com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) 在 com.example.android.tflitecamerademo.Camera2BasicFragment.access$900(Camera2BasicFragment.java:69) 在 com.example.android.tflitecamerademo.Camera2BasicFragment$5.run(Camera2BasicFragment.java:558) 在 android.os.Handler.handleCallback(Handler.java:751) 在 android.os.Handler.dispatchMessage(Handler.java:95) 在 android.os.Looper.loop(Looper.java:154) 在 android.os.HandlerThread.run(HandlerThread.java:61)
原因似乎是模型所需的输入尺寸是图像尺寸的四倍。所以我将DIM_BATCH_SIZE = 1
修改为DIM_BATCH_SIZE = 4
。现在的错误是:
致命异常:CameraBackground 进程:android.example.com.tflitecamerademo,PID:18241 java.lang.IllegalArgumentException:无法转换 TensorFlowLite 将 FLOAT32 类型的张量转换为 [[B 类型的 Java 对象(即 与 TensorFlowLite 类型 UINT8 兼容) 在 org.tensorflow.lite.Tensor.copyTo(Tensor.java:36) 在 org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:122) 在 org.tensorflow.lite.Interpreter.run(Interpreter.java:93) 在 com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108) 在 com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) 在 com.example.android.tflitecamerademo.Camera2BasicFragment.access$900(Camera2BasicFragment.java:69) 在 com.example.android.tflitecamerademo.Camera2BasicFragment$5.run(Camera2BasicFragment.java:558) 在 android.os.Handler.handleCallback(Handler.java:751) 在 android.os.Handler.dispatchMessage(Handler.java:95) 在 android.os.Looper.loop(Looper.java:154) 在 android.os.HandlerThread.run(HandlerThread.java:61)
我的问题是如何让简化的 MobileNet tflite 模型与 TF-lite Android Demo 一起使用。
(我实际上尝试了其他方法,例如使用提供的工具将 TF 冻结图转换为 TF-lite 模型,即使使用与 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md 中完全相同的示例代码,但转换后的 tflite 模型仍然无法在 Android 中运行演示。)
【问题讨论】:
能否请您在帖子正文(不仅仅是标题)中提出一个明确的问题?请查看this。 请注意,我也遇到了这种情况。奇怪的是,当我将这些相同的再训练模型放入 Poets 2 Lite 的 TensorFlow 演示应用程序中时(它与 OP 引用的 Tensorflow-Android Lite 演示共享大量代码。github.com/googlecodelabs/tensorflow-for-poets-2/tree/master/… 【参考方案1】:Tensorflow-Lite Android 演示中包含的 ImageClassifier.java 需要一个量化模型。截至目前,只有一种 Mobilenets 模型以量化形式提供:Mobilenet 1.0 224 Quant。
要使用其他浮点模型,请从 Tensorflow for Poets TF-Lite 演示源中换入 ImageClassifier.java。这是为 float 模型编写的。 https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tflite/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
做一个比较,你会发现在实现上有几个重要的区别。
另一个需要考虑的选项是使用 TOCO 将浮点模型转换为量化模型: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
【讨论】:
非常感谢。诗人的 TensorFlow 有效!对于您的第二个选项,我之前将一些在 Android TF 演示中工作的 MobileNet 模型量化为 .pb 文件。但是现在如果我使用 TOCO 将这种量化模型转换为 .tflite 模型,在转换过程中会出现一些问题。我认为致命的一个是 Unsupported TensorFlow op: Dequantize。你也经历过吗? 是的,我还没有将量化模型转换为 .tflite 的运气 现在有适用于所有 depth_mulitpliers 和图像大小的量化 mobilenet v1 模型。 github.com/tensorflow/models/blob/master/research/slim/nets/… @AshEldritch“将浮点模型转换为使用 TOCO 量化”需要在传递给 Toco 之前使用量化节点对图进行检测。【参考方案2】:我也遇到了与 Seedling 相同的错误。 我为 Mobilenet Float 模型创建了一个新的图像分类器包装器。 现在工作正常。您可以直接在图像分类器演示中添加该类,并使用它在 Camera2BasicFragment 中创建分类器
classifier = new ImageClassifierFloatMobileNet(getActivity());
下面是 Mobilenet Float 模型的图像分类器类包装器
/**
* This classifier works with the Float MobileNet model.
*/
public class ImageClassifierFloatMobileNet extends ImageClassifier
/**
* An array to hold inference results, to be feed into Tensorflow Lite as outputs.
* This isn't part of the super class, because we need a primitive array here.
*/
private float[][] labelProbArray = null;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128.0f;
/**
* Initializes an @code ImageClassifier.
*
* @param activity
*/
public ImageClassifierFloatMobileNet(Activity activity) throws IOException
super(activity);
labelProbArray = new float[1][getNumLabels()];
@Override
protected String getModelPath()
// you can download this file from
// https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
// return "mobilenet_quant_v1_224.tflite";
return "retrained.tflite";
@Override
protected String getLabelPath()
// return "labels_mobilenet_quant_v1_224.txt";
return "retrained_labels.txt";
@Override
public int getImageSizeX()
return 224;
@Override
public int getImageSizeY()
return 224;
@Override
protected int getNumBytesPerChannel()
// the Float model uses a 4 bytes
return 4;
@Override
protected void addPixelValue(int val)
imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
@Override
protected float getProbability(int labelIndex)
return labelProbArray[0][labelIndex];
@Override
protected void setProbability(int labelIndex, Number value)
labelProbArray[0][labelIndex] = value.byteValue();
@Override
protected float getNormalizedProbability(int labelIndex)
return labelProbArray[0][labelIndex];
@Override
protected void runInference()
tflite.run(imgData, labelProbArray);
【讨论】:
以上是关于TensorFlow-Lite 预训练模型在 Android 演示中不起作用的主要内容,如果未能解决你的问题,请参考以下文章