在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型

Posted

技术标签:

【中文标题】在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型【英文标题】:Loading pretrained model with Python tensorflow v.0.9.0 in Java tensorflow v.1.2.0 【发布时间】:2017-07-04 18:49:20 【问题描述】:

当 Java 和 Python tensorflow 版本都是 1.2.0 时,似乎我们可以使用 SavedModelBundle (Java) 和 Saved Model API (Python) 将训练好的模型保存在 Python tensorflow 中并在 Java tensorflow 中加载模型(不是与 Maven)。

但是,当 Python 版本低于 1.0 时,我找不到在 Java 中正确加载模型的方法。

我训练了一个模型并将其保存为 Python tensorflow (0.9.0) 中的 .pb、.sd 和 .txt 文件,并按照 tensorflow 网站中的 example 指令加载模型。但是,我收到以下错误:

Exception in thread "main" java.lang.IllegalStateException: Attempting 
to use uninitialized value policy/mean_network/hidden_1/b
            [[Node: _retval_policy/mean_network/hidden_1/b_0_0 = 
_Retval[T=DT_FLOAT, index=0, 
_device="/job:localhost/replica:0/task:0/cpu:0"]
(policy/mean_network/hidden_1/b)]]
            at org.tensorflow.Session.run(Native Method)
            at org.tensorflow.Session.access$100(Session.java:48)
            at org.tensorflow.Session$Runner.runHelper(Session.java:285)
            at org.tensorflow.Session$Runner.run(Session.java:235)
            at Carpole.executeGraph(Carpole.java:42)
            at Carpole.main(Carpole.java:30)

有谁知道如何在不使用 Saved Model API 的情况下在最新版本的 Java 中正确加载预训练模型(因为我再也找不到以前版本的 API)?

提前致谢!

这是我用于保存的 Python 代码:

with tf.Session() as sess:
    self.saver = tf.train.Saver(tf.all_variables())
    sess.run(tf.initialize_all_variables())
    …..
    saver_def = self.saver.as_saver_def()
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)

    self.saver.save(sess, 'trained_model'+str(itr)+'.sd')
    tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.pb', as_text=False)
    tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.txt', as_text=True)

这是我的 Java 代码

public static void main(String[] args) throws Exception 
    String dataDirPath = args[0];
    byte[] graphDef = readAllBytesOrExit(Paths.get(dataDirPath, "trained_model10.pb"));
    List<String> labels = readAllLinesOrExit(Paths.get(dataDirPath, "trained_model10.txt"));
    float[] vector = new float[4];
    vector[0] = (float) -0.09341373;
    vector[1] = (float) -0.07540844;
    vector[2] = (float)  0.00930138;
    vector[3] = (float) -0.14317159;
    Tensor input = Tensor.create(vector);

    float[] labelProbabilities = executeGraph(graphDef, input);
    int bestLabelIdx = maxIndex(labelProbabilities);
    System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)",labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));


private static float[] executeGraph(byte[] graphDef, Tensor input_tensor) 
    try (Graph g = new Graph()) 
        g.importGraphDef(graphDef);
        System.out.println(g);
        try (Session s = new Session(g); Tensor result = s.runner().feed("policy/mean_network/input/input",input_tensor).fetch("policy/mean_network/hidden_1/b").run().get(0)) 
        final long[] rshape = result.shape();
        
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
    

【问题讨论】:

【参考方案1】:

一般而言,不保证 1.0 之前的 TensorFlow 版本可以与 >= 1.0 的 TensorFlow 版本一起使用(根据 TensorFlow Version Semantics 和语义版本控制的使用)。

也就是说,查看您提供的代码 sn-ps,您正在加载 Java 中的计算图,但没有加载保存的变量,因此您会收到一个异常,抱怨变量尚未初始化。

将图形和保存的变量封装到一个包中是 SavedModel 格式的用途。但是,如果您不能使用它,并且如果您只需要图形在 Java 中进行推理,那么您可能需要考虑“冻结”图形,然后在 Java 中加载它。冻结图将包含单个文件中的所有变量值。

您可以尝试使用 freeze_graph 库(来自 0.9 版本分支)来保存这样的图表。

使用 TensorFlow 1.0 之前的版本可能是一个挑战。如果可能,我强烈建议您将模型移至 TensorFlow 版本 >= 1.0,然后可以利用 API 稳定性保证。

希望对您有所帮助。

【讨论】:

谢谢!但是,不知何故,我的 tensorflow 没有 tensorflow/python/tools 模块(我使用的是 anaconda2)。我最终升级到 v.1.1.0,稍微修改了我的代码并使用了 Saved Model API。

以上是关于在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型的主要内容,如果未能解决你的问题,请参考以下文章

干货使用TensorFlow官方Java API调用TensorFlow模型(附代码)

用 Python 生成的 Tensorflow 数据集在 Tensorflow Java API(标签图像)中有不同的读数

Java应用XV使用Java中的TensorFlow来构建和训练机器学习模型

在 android (Java 8) 上运行 TensorFlow Lite 时出现 java.lang.NoSuchMethodError

在智能手机上使用 TensorFlow

Java / Tensorflow - API 调用 pb 模型使用 GPU 推理