恢复预训练模型的 TensorFlow 检查点文件

Posted

技术标签:

【中文标题】恢复预训练模型的 TensorFlow 检查点文件【英文标题】:Restoring Tensorflow checkpoint files of a pre-trained model 【发布时间】:2018-02-16 19:47:19 【问题描述】:

我已经从MobileNet 检查点文件下载了一个预训练的 TF-Slim 模型,我正在尝试查看与层相关的权重。

例如,我有三个文件:

67903136 Jun 14 00:15 mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
19954 Jun 14 00:15 mobilenet_v1_1.0_224.ckpt.index
4319476 Jun 14 00:15 mobilenet_v1_1.0_224.ckpt.meta

第一种方法:

我直接使用tensorboard

tensorboard --logdir=$CHKPNT_DIR

它在本地运行 (http://127.0.0.1:6006/),但不显示任何内容:

No dashboards are active for the current data set.
Probable causes:

You haven’t written any data to your event files.
TensorBoard can’t find your event files.

第二种方法:

我用的是后端方法Event_Accumulator

import tensorflow as tf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
event_acc = EventAccumulator('$CHKPNT_DIR')
event_acc.Reload()

# Show all tags in the log file
print(event_acc.Tags())

有趣的是,所有的标签都是空的:

'scalars': [], 'histograms': [], 'meta_graph': False, 'images': [], 'graph': False, 'audio': [], 'distributions': [], 'tensors': [], 'run_metadata': []

所以这是dir

>>>dir(event_acc)
['Audio', 'CompressedHistograms', 'FirstEventTimestamp', 'Graph', 'Histograms', 'Images', 'MetaGraph', 'PluginAssets', 'PluginTagToContent', 'Reload', 'RetrievePluginAsset', 'RunMetadata', 'Scalars', 'SummaryMetadata', 'Tags', 'Tensors', '_CheckForOutOfOrderStepAndMaybePurge', '_CheckForRestartAndMaybePurge', '_CompressHistogram', '_ConvertHistogramProtoToTuple', '_MaybePurgeOrphanedData', '_ProcessAudio', '_ProcessEvent', '_ProcessHistogram', '_ProcessImage', '_ProcessScalar', '_ProcessTensor', '_Purge', '__class__', '__delattr__', '__dict__', '__doc__', '__format__', '__getattribute__', '__hash__', '__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_compression_bps', '_first_event_timestamp', '_generator', '_generator_mutex', '_graph', '_graph_from_metagraph', '_meta_graph', '_plugin_to_tag_to_content', '_tagged_metadata', '_tensor_summaries', 'accumulated_attrs', 'audios', 'compressed_histograms', 'file_version', 'histograms', 'images', 'most_recent_step', 'most_recent_wall_time', 'path', 'purge_orphaned_data', 'scalars', 'summary_metadata', 'tensors']

那么,我们应该如何查看已经预训练的网络?这些检查点文件必须包含某种 Google Protobuf 数据。

在 Mac OS 10.12.4 上运行 TF 1.3.0。

【问题讨论】:

【参考方案1】:
saver = tf.train.import_meta_graph("path/your/meta/file")
saver.restore("path/to/data/file")

graph = tf.get_default_graph()
writer = tf.summary.FileWriter("path/to/write/graph")
writer.add_graph(graph)

然后

tensorboard --logdir="path/to/write/graph"

【讨论】:

我收到以下错误:ValueError: No op named SSTableReaderV2 in defined operations. 这似乎是一个常见问题:github.com/tensorflow/models/issues/1564 你能看到权重@Jie

以上是关于恢复预训练模型的 TensorFlow 检查点文件的主要内容,如果未能解决你的问题,请参考以下文章

AI - TensorFlow - 示例05:保存和恢复模型

Tensorflow:在具有不同类别数量的新数据集上微调预训练模型

如何在 Tensorflow 对象检测 api 中评估预训练模型

Dataset_factory importerror:Tensorflow 从自定义数据的现有检查点微调预训练模型

tensorflow saver 保存和恢复指定 tensor

tensorflow saver 保存和恢复指定 tensor