TensorFlow,为啥保存模型后有3个文件?
Posted
技术标签:
【中文标题】TensorFlow,为啥保存模型后有3个文件?【英文标题】:TensorFlow, why there are 3 files after saving the model?TensorFlow,为什么保存模型后有3个文件? 【发布时间】:2017-05-07 00:14:46 【问题描述】:阅读了docs,我在TensorFlow
中保存了一个模型,这是我的演示代码:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
但在那之后,我发现有 3 个文件
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
而且我无法通过恢复model.ckpt
文件来恢复模型,因为没有这样的文件。这是我的代码
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
那么,为什么会有 3 个文件?
【问题讨论】:
你知道如何解决这个问题吗?如何再次加载模型(使用 Keras)? 【参考方案1】:试试这个:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
saver.restore(sess, "/tmp/model.ckpt")
TensorFlow save 方法保存三种文件,因为它存储 图形结构 与 变量值 分开。 .meta
文件描述了保存的图结构,所以需要在恢复checkpoint之前导入(否则不知道保存的checkpoint值对应什么变量)。
或者,您可以这样做:
# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Now load the checkpoint variable values
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")
即使没有名为model.ckpt
的文件,您在恢复时仍会使用该名称引用已保存的检查点。来自saver.py
source code:
用户只需要与用户指定的前缀交互...而不是 任何物理路径名。
【讨论】:
所以不使用 .index 和 .data 吗?那这两个文件什么时候用呢? @ajfbiw.s .meta 存储图结构,.data 存储图中每个变量的值,.index 标识检查点。所以在上面的例子中:import_meta_graph 使用 .meta,saver.restore 使用 .data 和 .index 哦,我明白了。谢谢。 您是否有机会使用与加载模型不同的 TensorFlow 版本保存模型? (github.com/tensorflow/tensorflow/issues/5639) 有谁知道00000
和00001
的数字是什么意思?在variables.data-?????-of-?????
文件中【参考方案2】:
元文件:描述保存的图结构,包括GraphDef、SaverDef等;然后应用tf.train.import_meta_graph('/tmp/model.ckpt.meta')
,将恢复Saver
和Graph
。
索引文件:它是一个字符串字符串不可变表(tensorflow::table::Table)。每个键是一个张量的名称,其值是一个序列化的 BundleEntryProto。每个 BundleEntryProto 都描述了张量的元数据:哪些“数据”文件包含张量的内容、该文件的偏移量、校验和、一些辅助数据等。
数据文件:是 TensorBundle 集合,保存所有变量的值。
【讨论】:
我得到了用于图像分类的 pb 文件。我可以将它用于实时视频分类吗? 你能告诉我,使用 Keras 2,如果模型保存为 3 个文件,我该如何加载?【参考方案3】:我正在从Word2Vec tensorflow 教程中恢复经过训练的词嵌入。
如果您创建了多个检查点:
例如创建的文件如下所示
model.ckpt-55695.data-00000-of-00001
model.ckpt-55695.index
model.ckpt-55695.meta
试试这个
def restore_session(self, session):
saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
saver.restore(session, './tmp/model.ckpt-55695')
调用 restore_session() 时:
def test_word2vec():
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
with tf.device("/cpu:0"):
model = Word2Vec(opts, session)
model.restore_session(session)
model.get_embedding("assistance")
【讨论】:
“model.ckpt-55695.data-00000-of-00001”中的“00000-of-00001”是什么意思? @hafiz031 后缀.data-00000-of-00001
引用了 Tensorflow 在多台机器上训练的场景中使用的分片。对于单台机器上的训练,您将只有这个后缀。【参考方案4】:
例如,如果你训练了一个带有 dropout 的 CNN,你可以这样做:
def predict(image, model_name):
"""
image -> single image, (width, height, channels)
model_name -> model file that was saved without any extensions
"""
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./' + model_name + '.meta')
saver.restore(sess, './' + model_name)
# Substitute 'logits' with your model
prediction = tf.argmax(logits, 1)
# 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
return prediction.eval(feed_dict=x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0)
【讨论】:
以上是关于TensorFlow,为啥保存模型后有3个文件?的主要内容,如果未能解决你的问题,请参考以下文章
为啥我的 tensorflow 模型输出在 x 个时期后变为 NaN?