在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型
Posted
技术标签:
【中文标题】在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型【英文标题】:Loading pretrained model with Python tensorflow v.0.9.0 in Java tensorflow v.1.2.0 【发布时间】:2017-07-04 18:49:20 【问题描述】:当 Java 和 Python tensorflow 版本都是 1.2.0 时,似乎我们可以使用 SavedModelBundle (Java) 和 Saved Model API (Python) 将训练好的模型保存在 Python tensorflow 中并在 Java tensorflow 中加载模型(不是与 Maven)。
但是,当 Python 版本低于 1.0 时,我找不到在 Java 中正确加载模型的方法。
我训练了一个模型并将其保存为 Python tensorflow (0.9.0) 中的 .pb、.sd 和 .txt 文件,并按照 tensorflow 网站中的 example 指令加载模型。但是,我收到以下错误:
Exception in thread "main" java.lang.IllegalStateException: Attempting
to use uninitialized value policy/mean_network/hidden_1/b
[[Node: _retval_policy/mean_network/hidden_1/b_0_0 =
_Retval[T=DT_FLOAT, index=0,
_device="/job:localhost/replica:0/task:0/cpu:0"]
(policy/mean_network/hidden_1/b)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at Carpole.executeGraph(Carpole.java:42)
at Carpole.main(Carpole.java:30)
有谁知道如何在不使用 Saved Model API 的情况下在最新版本的 Java 中正确加载预训练模型(因为我再也找不到以前版本的 API)?
提前致谢!
这是我用于保存的 Python 代码:
with tf.Session() as sess:
self.saver = tf.train.Saver(tf.all_variables())
sess.run(tf.initialize_all_variables())
…..
saver_def = self.saver.as_saver_def()
print(saver_def.filename_tensor_name)
print(saver_def.restore_op_name)
self.saver.save(sess, 'trained_model'+str(itr)+'.sd')
tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.pb', as_text=False)
tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.txt', as_text=True)
这是我的 Java 代码
public static void main(String[] args) throws Exception
String dataDirPath = args[0];
byte[] graphDef = readAllBytesOrExit(Paths.get(dataDirPath, "trained_model10.pb"));
List<String> labels = readAllLinesOrExit(Paths.get(dataDirPath, "trained_model10.txt"));
float[] vector = new float[4];
vector[0] = (float) -0.09341373;
vector[1] = (float) -0.07540844;
vector[2] = (float) 0.00930138;
vector[3] = (float) -0.14317159;
Tensor input = Tensor.create(vector);
float[] labelProbabilities = executeGraph(graphDef, input);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)",labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
private static float[] executeGraph(byte[] graphDef, Tensor input_tensor)
try (Graph g = new Graph())
g.importGraphDef(graphDef);
System.out.println(g);
try (Session s = new Session(g); Tensor result = s.runner().feed("policy/mean_network/input/input",input_tensor).fetch("policy/mean_network/hidden_1/b").run().get(0))
final long[] rshape = result.shape();
int nlabels = (int) rshape[1];
return result.copyTo(new float[1][nlabels])[0];
【问题讨论】:
【参考方案1】:一般而言,不保证 1.0 之前的 TensorFlow 版本可以与 >= 1.0 的 TensorFlow 版本一起使用(根据 TensorFlow Version Semantics 和语义版本控制的使用)。
也就是说,查看您提供的代码 sn-ps,您正在加载 Java 中的计算图,但没有加载保存的变量,因此您会收到一个异常,抱怨变量尚未初始化。
将图形和保存的变量封装到一个包中是 SavedModel 格式的用途。但是,如果您不能使用它,并且如果您只需要图形在 Java 中进行推理,那么您可能需要考虑“冻结”图形,然后在 Java 中加载它。冻结图将包含单个文件中的所有变量值。
您可以尝试使用 freeze_graph
库(来自 0.9 版本分支)来保存这样的图表。
使用 TensorFlow 1.0 之前的版本可能是一个挑战。如果可能,我强烈建议您将模型移至 TensorFlow 版本 >= 1.0,然后可以利用 API 稳定性保证。
希望对您有所帮助。
【讨论】:
谢谢!但是,不知何故,我的 tensorflow 没有 tensorflow/python/tools 模块(我使用的是 anaconda2)。我最终升级到 v.1.1.0,稍微修改了我的代码并使用了 Saved Model API。以上是关于在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型的主要内容,如果未能解决你的问题,请参考以下文章
干货使用TensorFlow官方Java API调用TensorFlow模型(附代码)
用 Python 生成的 Tensorflow 数据集在 Tensorflow Java API(标签图像)中有不同的读数
Java应用XV使用Java中的TensorFlow来构建和训练机器学习模型
在 android (Java 8) 上运行 TensorFlow Lite 时出现 java.lang.NoSuchMethodError