在 Android 上使用来自冻结的张量流图的变量

Posted

技术标签:

【中文标题】在 Android 上使用来自冻结的张量流图的变量【英文标题】:Use Variables from frozen tensorflow graph on Android 【发布时间】:2018-11-17 13:22:20 【问题描述】:

TLDR:如何在 Android 上使用冻结的张量流图中的变量?


1.我想做什么

我有一个 Tensorflow 模型,它在多个变量中保持内部状态,创建于:state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32), name='state', trainable=False)

此状态在推理期间被修改:

tf.assign(state_var, new_value)

我现在想在 android 上部署模型。我能够让 Tensorflow 示例应用程序运行。在那里,加载了一个冻结的模型,效果很好。


2.从冻结图恢复变量不起作用

但是,当您使用freeze_graph script 冻结图形时,所有变量都将转换为常量。这对于网络的权重来说很好,但对于内部状态则不然。推理失败并显示以下消息。我将此解释为“assign 不适用于常量张量”

java.lang.RuntimeException: Failed to load model from 'file:///android_asset/model.pb'
at org.tensorflow.contrib.android.TensorFlowInferenceInterface.<init>(TensorFlowInferenceInterface.java:113)
...
Caused by: java.io.IOException: Not a valid TensorFlow Graph serialization: Input 0 of node layer_1/Assign was passed float from layer_1/state:0 incompatible with expected float_ref.

幸运的是,您可以将被转换为常量的变量列入黑名单。但是,这也不起作用,因为冻结的图形现在包含未初始化的变量。

java.lang.IllegalStateException: Attempting to use uninitialized value layer_7/state

3.恢复 SavedModel 在 Android 上不起作用

我尝试过的最后一个版本是使用SavedModel 格式,它应该包含冻结图和变量。不幸的是,调用 restore 方法在 Android 上不起作用。

SavedModelBundle bundle = SavedModelBundle.load(modelFilename, modelTag);

// produces error:

E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.demo, PID: 27451
     java.lang.UnsupportedOperationException: Loading a SavedModel is not supported in Android. File a bug at https://github.com/tensorflow/tensorflow/issues if this feature is important to you at org.tensorflow.SavedModelBundle.load(Native Method)

4.我怎样才能做到这一点?

我不知道我还能尝试什么。这是我的想象,但我不知道如何使它工作:

    找出在 Android 上初始化变量的方法 找出另一种冻结模型的方法,以便初始化程序操作可能也是冻结图的一部分,并且可以从 Android 运行 了解 RNN/LSTM 是否/如何在内部实现,因为它们也应该具有在推理期间使用变量的相同要求(我假设 LSTM 能够部署在 Android 上)。 ???

【问题讨论】:

你想对变量做什么?如果它只是在推理过程中使用的局部变量,那么您可以使用控制依赖项 (with tf.control_dependencies([var.assign(initial_value)]):) 对其进行初始化。然后你在控制依赖块中放置的任何东西都将在变量初始化后运行。但这不适用于保持模型的状态。如果您想这样做,您需要在单独的运行调用中初始化变量(例如,从检查点恢复它或提供初始值)。另一种选择是提供值而不是使用变量。 感谢您的评论。我走了你建议的最后一条路线(“提供值而不是使用变量”),因为这是我发现在 Android 上工作的最佳解决方案。我已经用更多细节自己回答了这个问题 【参考方案1】:

我自己通过走不同的路线解决了这个问题。据我所知,“变量”概念在 Android 上的使用方式与我在 Python 中使用的方式不同(例如,您无法初始化变量,然后在推理期间更新网络的内部状态)。

相反,您可以使用占位符和输出节点来保存 Java 代码中的状态,并在每次推理调用时将其提供给网络。

将所有出现的tf.Variable 替换为tf.placeholder。形状保持不变。 我还定义了一个用于读取输出的附加节点。 (也许你可以简单地阅读占位符本身,我还没有尝试过。)tf.identity(inputs, name='state_output')

在 Android 上进行推理期间,您将初始状态输入网络。

float[] values = 0, 0, 0, ...; // zeros of the correct shape inferenceInterface.feed('state', values, ...);

推理后,您读取网络的最终内部状态

float[] values = new float[output_shape]; inferenceInterface.fetch('state_output', values);

然后,您可以记住 Java 中的此输出,以将其传递到 'state' 占位符以进行下一次推理调用。

【讨论】:

以上是关于在 Android 上使用来自冻结的张量流图的变量的主要内容,如果未能解决你的问题,请参考以下文章

是否可以在没有训练操作的情况下可视化张量流图?

pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?

Tensorflow(4) 张量属性:维数、形状、数据类型

Tensorflow瞎搞

tensorflow中张量(tensor)的属性——维数(阶)形状和数据类型

tensorflow零碎知识