在 Android 上使用来自冻结的张量流图的变量
Posted
技术标签:
【中文标题】在 Android 上使用来自冻结的张量流图的变量【英文标题】:Use Variables from frozen tensorflow graph on Android 【发布时间】:2018-11-17 13:22:20 【问题描述】:TLDR:如何在 Android 上使用冻结的张量流图中的变量?
1.我想做什么
我有一个 Tensorflow 模型,它在多个变量中保持内部状态,创建于:state_var = tf.Variable(tf.zeros(shape, dtype=tf.float32), name='state', trainable=False)
。
此状态在推理期间被修改:
tf.assign(state_var, new_value)
我现在想在 android 上部署模型。我能够让 Tensorflow 示例应用程序运行。在那里,加载了一个冻结的模型,效果很好。
2.从冻结图恢复变量不起作用
但是,当您使用freeze_graph script 冻结图形时,所有变量都将转换为常量。这对于网络的权重来说很好,但对于内部状态则不然。推理失败并显示以下消息。我将此解释为“assign 不适用于常量张量”
java.lang.RuntimeException: Failed to load model from 'file:///android_asset/model.pb'
at org.tensorflow.contrib.android.TensorFlowInferenceInterface.<init>(TensorFlowInferenceInterface.java:113)
...
Caused by: java.io.IOException: Not a valid TensorFlow Graph serialization: Input 0 of node layer_1/Assign was passed float from layer_1/state:0 incompatible with expected float_ref.
幸运的是,您可以将被转换为常量的变量列入黑名单。但是,这也不起作用,因为冻结的图形现在包含未初始化的变量。
java.lang.IllegalStateException: Attempting to use uninitialized value layer_7/state
3.恢复 SavedModel 在 Android 上不起作用
我尝试过的最后一个版本是使用SavedModel
格式,它应该包含冻结图和变量。不幸的是,调用 restore 方法在 Android 上不起作用。
SavedModelBundle bundle = SavedModelBundle.load(modelFilename, modelTag);
// produces error:
E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.tensorflow.demo, PID: 27451
java.lang.UnsupportedOperationException: Loading a SavedModel is not supported in Android. File a bug at https://github.com/tensorflow/tensorflow/issues if this feature is important to you at org.tensorflow.SavedModelBundle.load(Native Method)
4.我怎样才能做到这一点?
我不知道我还能尝试什么。这是我的想象,但我不知道如何使它工作:
-
找出在 Android 上初始化变量的方法
找出另一种冻结模型的方法,以便初始化程序操作可能也是冻结图的一部分,并且可以从 Android 运行
了解 RNN/LSTM 是否/如何在内部实现,因为它们也应该具有在推理期间使用变量的相同要求(我假设 LSTM 能够部署在 Android 上)。
???
【问题讨论】:
你想对变量做什么?如果它只是在推理过程中使用的局部变量,那么您可以使用控制依赖项 (with tf.control_dependencies([var.assign(initial_value)]):
) 对其进行初始化。然后你在控制依赖块中放置的任何东西都将在变量初始化后运行。但这不适用于保持模型的状态。如果您想这样做,您需要在单独的运行调用中初始化变量(例如,从检查点恢复它或提供初始值)。另一种选择是提供值而不是使用变量。
感谢您的评论。我走了你建议的最后一条路线(“提供值而不是使用变量”),因为这是我发现在 Android 上工作的最佳解决方案。我已经用更多细节自己回答了这个问题
【参考方案1】:
我自己通过走不同的路线解决了这个问题。据我所知,“变量”概念在 Android 上的使用方式与我在 Python 中使用的方式不同(例如,您无法初始化变量,然后在推理期间更新网络的内部状态)。
相反,您可以使用占位符和输出节点来保存 Java 代码中的状态,并在每次推理调用时将其提供给网络。
将所有出现的tf.Variable
替换为tf.placeholder
。形状保持不变。
我还定义了一个用于读取输出的附加节点。 (也许你可以简单地阅读占位符本身,我还没有尝试过。)tf.identity(inputs, name='state_output')
在 Android 上进行推理期间,您将初始状态输入网络。
float[] values = 0, 0, 0, ...; // zeros of the correct shape
inferenceInterface.feed('state', values, ...);
推理后,您读取网络的最终内部状态
float[] values = new float[output_shape];
inferenceInterface.fetch('state_output', values);
然后,您可以记住 Java 中的此输出,以将其传递到 'state'
占位符以进行下一次推理调用。
【讨论】:
以上是关于在 Android 上使用来自冻结的张量流图的变量的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?