Tensorflow加载预训练模型的特殊操作

Posted 走召大爷

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow加载预训练模型的特殊操作相关的知识,希望对你有一定的参考价值。

最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】

在前面的文章【Tensorflow加载预训练模型和保存模型】中介绍了如何保存训练好的模型,已经将预训练好的模型参数加载到当前网络。这些属于常规操作,即预训练的模型与当前网络结构的命名完全一致。

本文介绍一些不常规的操作:

  1. 如何只加载部分参数?
  2. 如何从两个模型中加载不同部分参数?
  3. 当预训练的模型的命名与当前定义的网络中的参数命名不一致时该怎么办?

1 只加载部分参数

举个例子,对已有的网络结构做了细微修改,例如只改了几层卷积通道数。如果从头训练显然没有finetune收敛速度快,但是模型又没法全部加载。此时,只需将未修改部分参数加载到当前网络即可。假设修改过的卷积层名称包含`conv_``,示例代码如下:

import tensorflow as tf
def restore(sess, ckpt_path):
	vars = tf.trainable_variables()
	vars = [v for v vars if not "conv_1" in v.name]
    saver = tf.train.Saver(var_list=vars)
	saver.restore(sess, ckpt_path)

2 从两个预训练模型中加载不同部分参数

如果需要从两个不同的预训练模型中加载不同部分参数,例如,网络中的前半部分用一个预训练模型参数,后半部分用另一个预训练模型中的参数,示例代码如下:

import tensorflow as tf
def restore(sess, ckpt_path):
	vars = tf.trainable_variables()
	model_1_vars = [v for v vars if "model_1" in v.name]
	model_2_vars = [v for v vars if "model_2" in v.name]
    saver_1 = tf.train.Saver(var_list=model_1_vars)
    saver_2 = tf.train.Saver(var_list=model_2_vars)
	saver_1 .restore(sess, ckpt_path)
	saver_2 .restore(sess, ckpt_path)

3 从参数名称不一致的模型中加载参数

举个例子,例如,预训练的模型所有的参数有个前缀name_1,现在定义的网络结构中的参数以name_2作为前缀。那么使用如下示例代码即可加载:

import tensorflow as tf
def restore(sess, ckpt_path):
	vars = tf.trainable_variables()
	vars_dict = dict()
	for v in vars:
	    key = v.name.split(':')[0]
	    if key.startswith("name_2/"):
	        key = key.replace("name_2/", "name_1/")
	    vars_dict[key] = v
	saver =tf.train.Saver(var_list=vars_dict)
	saver.restore(sess, ckpt_path)

注意: 使用上面代码时,要确保参数的shape一致,否则会无法加载参数。

如果不知道预训练的ckpt中参数名称,可以使用如下代码打印:

for name, shape in tf.train.list_variables(ckpt_path):
    print(name)

如果您觉得本文有帮助,辛苦您点个不需花钱的赞,您的举手之劳将对我提供了无限的写作动力! 也欢迎关注我的公众号:Python学习实战, 第一时间获取最新文章。

以上是关于Tensorflow加载预训练模型的特殊操作的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow加载预训练模型和保存模型

在 Java tensorflow v.1.2.0 中使用 Python tensorflow v.0.9.0 加载预训练模型

在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

当我尝试在jetson tx1中加载卷积预训练模型时,tensorflow中的错误被杀死

转 tensorflow模型保存 与 加载

求助 tensorflow怎样恢复预训练的模型啊