TensorFlow,为啥保存模型后有3个文件?

Posted

技术标签:

【中文标题】TensorFlow,为啥保存模型后有3个文件?【英文标题】:TensorFlow, why there are 3 files after saving the model?TensorFlow,为什么保存模型后有3个文件? 【发布时间】:2017-05-07 00:14:46 【问题描述】:

阅读了docs,我在TensorFlow中保存了一个模型,这是我的演示代码:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

但在那之后,我发现有 3 个文件

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

而且我无法通过恢复model.ckpt 文件来恢复模型,因为没有这样的文件。这是我的代码

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

那么,为什么会有 3 个文件?

【问题讨论】:

你知道如何解决这个问题吗?如何再次加载模型(使用 Keras)? 【参考方案1】:

试试这个:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

TensorFlow save 方法保存三种文件,因为它存储 图形结构变量值 分开。 .meta文件描述了保存的图结构,所以需要在恢复checkpoint之前导入(否则不知道保存的checkpoint值对应什么变量)。

或者,您可以这样做:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

即使没有名为model.ckpt 的文件,您在恢复时仍会使用该名称引用已保存的检查点。来自saver.py source code:

用户只需要与用户指定的前缀交互...而不是 任何物理路径名。

【讨论】:

所以不使用 .index 和 .data 吗?那这两个文件什么时候用呢? @ajfbiw.s .meta 存储图结构,.data 存储图中每个变量的值,.index 标识检查点。所以在上面的例子中:import_meta_graph 使用 .meta,saver.restore 使用 .data 和 .index 哦,我明白了。谢谢。 您是否有机会使用与加载模型不同的 TensorFlow 版本保存模型? (github.com/tensorflow/tensorflow/issues/5639) 有谁知道0000000001 的数字是什么意思?在variables.data-?????-of-????? 文件中【参考方案2】:

元文件:描述保存的图结构,包括GraphDef、SaverDef等;然后应用tf.train.import_meta_graph('/tmp/model.ckpt.meta'),将恢复SaverGraph

索引文件:它是一个字符串字符串不可变表(tensorflow::table::Table)。每个键是一个张量的名称,其值是一个序列化的 BundleEntryProto。每个 BundleEntryProto 都描述了张量的元数据:哪些“数据”文件包含张量的内容、该文件的偏移量、校验和、一些辅助数据等。

数据文件:是 TensorBundle 集合,保存所有变量的值。

【讨论】:

我得到了用于图像分类的 pb 文件。我可以将它用于实时视频分类吗? 你能告诉我,使用 Keras 2,如果模型保存为 3 个文件,我该如何加载?【参考方案3】:

我正在从Word2Vec tensorflow 教程中恢复经过训练的词嵌入。

如果您创建了多个检查点:

例如创建的文件如下所示

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

试试这个

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

调用 restore_session() 时:

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

【讨论】:

“model.ckpt-55695.data-00000-of-00001”中的“00000-of-00001”是什么意思? @hafiz031 后缀 .data-00000-of-00001 引用了 Tensorflow 在多台机器上训练的场景中使用的分片。对于单台机器上的训练,您将只有这个后缀。【参考方案4】:

例如,如果你训练了一个带有 dropout 的 CNN,你可以这样做:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict=x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0)

【讨论】:

以上是关于TensorFlow,为啥保存模型后有3个文件?的主要内容,如果未能解决你的问题,请参考以下文章

为啥我的 tensorflow 模型输出在 x 个时期后变为 NaN?

tensorflow模型的保存与恢复

Tensorflow加载预训练模型和保存模型

[ML] Tensorflow2 保存完整模型以及使用 HDF5

为啥重启flink程序后有很多in-progress文件?

Tensorflow:模型保存和服务