如何将经过训练的 caffe 模型以 h5 格式加载到 c++ caffe 网络?

Posted

技术标签:

【中文标题】如何将经过训练的 caffe 模型以 h5 格式加载到 c++ caffe 网络?【英文标题】:How to load trained caffe model in h5 format to c++ caffe net? 【发布时间】:2017-07-09 07:38:28 【问题描述】:

正常训练的 caffe 模型是 .caffemodel 扩展名,实际上它们是 binary protobuf 格式。

知道如何在 C++ 中将 hdf5 格式的 caffe 模型加载到 caffe net 吗?

我有一个使用hdf5 格式的python caffe 训练的模型。

我的应用程序是在 c++ 中使用 caffe c++ 版本,我更喜欢使用 c++ 而不是 python。

如何将hdf5格式的caffe训练模型中的模型读取到c++ caffe net?

我知道caffe里面有hdf5data层。 有没有一个示例程序?

编辑:

我使用 CopyTrainedLayersFromHDF5() api 并得到以下运行时错误。

HDF5-DIAG: Error detected in HDF5 (1.8.11) thread 140737353775552:
  #000: ../../../src/H5G.c line 463 in H5Gopen2(): unable to open group
    major: Symbol table
    minor: Can't open object
  #001: ../../../src/H5Gint.c line 320 in H5G__open_name(): group not found
    major: Symbol table
    minor: Object not found
  #002: ../../../src/H5Gloc.c line 430 in H5G_loc_find(): can't find object
    major: Symbol table
    minor: Object not found
  #003: ../../../src/H5Gtraverse.c line 861 in H5G_traverse(): internal path traversal failed
    major: Symbol table
    minor: Object not found
  #004: ../../../src/H5Gtraverse.c line 641 in H5G_traverse_real(): traversal operator failed
    major: Symbol table
    minor: Callback failed
  #005: ../../../src/H5Gloc.c line 385 in H5G_loc_find_cb(): object 'data' doesn't exist
    major: Symbol table
    minor: Object not found
F0220 15:32:14.272573 24576 net.cpp:811] Check failed: data_hid >= 0 (-1 vs. 0) Error reading weights from model_800000.h5
*** Check failure stack trace: ***
    @     0x7ffff64afdcd  google::LogMessage::Fail()
    @     0x7ffff64b1d08  google::LogMessage::SendToLog()
    @     0x7ffff64af963  google::LogMessage::Flush()
    @     0x7ffff64b263e  google::LogMessageFatal::~LogMessageFatal()
    @     0x7ffff691c3a3  caffe::Net<>::CopyTrainedLayersFromHDF5()
    @           0x40828d  ExtractFeature::ExtractFeature()
    @           0x40ce78  main
    @     0x7ffff5bf8f45  __libc_start_main
    @           0x4080c9  (unknown)

Program received signal SIGABRT, Aborted.
0x00007ffff5c0dc37 in __GI_raise (sig=sig@entry=6)
    at ../nptl/sysdeps/unix/sysv/linux/raise.c:56
56  ../nptl/sysdeps/unix/sysv/linux/raise.c: No such file or directory.
(gdb) cd 
[17]+  Stopped                 gdb ./endtoendlib

编辑 1:

>>h5ls model_800000.h5 command gave me

conv1                    Group
conv2                    Group
forget_gate              Dataset 1, 250, 1, 1274
inception_3a             Group
inception_3b             Group
inception_4a             Group
inception_4b             Group
inception_4c             Group
inception_4d             Group
inception_4e             Group
inception_5a             Group
inception_5b             Group
input_gate               Dataset 1, 250, 1, 1274
input_value              Dataset 1, 250, 1, 1274
ip_bbox_unscaled0.p0     Dataset 4, 250
ip_bbox_unscaled0.p1     Dataset 4
ip_bbox_unscaled1.p0     Dataset 4, 250
ip_bbox_unscaled1.p1     Dataset 4
ip_bbox_unscaled2.p0     Dataset 4, 250
ip_bbox_unscaled2.p1     Dataset 4
ip_bbox_unscaled3.p0     Dataset 4, 250
ip_bbox_unscaled3.p1     Dataset 4
ip_bbox_unscaled4.p0     Dataset 4, 250
ip_bbox_unscaled4.p1     Dataset 4
ip_conf0.p0              Dataset 2, 250
ip_conf0.p1              Dataset 2
ip_conf1.p0              Dataset 2, 250
ip_conf1.p1              Dataset 2
ip_conf2.p0              Dataset 2, 250
ip_conf2.p1              Dataset 2
ip_conf3.p0              Dataset 2, 250
ip_conf3.p1              Dataset 2
ip_conf4.p0              Dataset 2, 250
ip_conf4.p1              Dataset 2
output_gate              Dataset 1, 250, 1, 1274
post_fc7_conv.p0         Dataset 1024, 1024, 1, 1
post_fc7_conv.p1         Dataset 1024

【问题讨论】:

是的,我检查了 .h5 文件,它是二进制格式。 但是二进制格式正确吗?如果你h5ls model_80000.h5 你会得到什么?看起来好像 caffe 期望文件有一个不存在的 'data' 数据集...... h5ls 命令在 EDIT1 中给了我。我看到了数据集。 你是如何保存这个model_80000.h5文件的?它是由咖啡保存的吗?将权重保存到 hdf5 文件中的 caffe net.ToHDF5(...) 方法将参数保存到 data(以及可选的 diff)数据集中。您的文件没有此数据集。您可以尝试手动“调整”文件并将data/ 添加到文件中的所有数据集... 【参考方案1】:

你考虑过net对象方法void CopyTrainedLayersFromHDF5(const string trained_filename);吗?它似乎可以满足您的需求。

至于"HDF5Data" 层:您在这里混淆了两件事。您拥有的 hdf5 文件存储了网络的训练参数。相比之下,"HDF5Data" 层存储了用于训练网络的训练示例

【讨论】:

是的,您的解决方案是正确的。我需要处理我的 h5 文件,以便可以将其加载到 c++ caffe。

以上是关于如何将经过训练的 caffe 模型以 h5 格式加载到 c++ caffe 网络?的主要内容,如果未能解决你的问题,请参考以下文章

具有更大输入图像尺寸的 Caffe 预训练模型

[caffe]Python加载训练caffe模型并进行测试2

将经过训练的 HDF5 模型加载到 Rust 中以进行预测

如何训练 ML 模型? [复制]

Caffe CNN 训练过程陷入循环

如何在 Keras 中保存经过训练的模型以在应用程序中使用它?