Tensorflow:在具有不同类别数量的新数据集上微调预训练模型
Posted
技术标签:
【中文标题】Tensorflow:在具有不同类别数量的新数据集上微调预训练模型【英文标题】:Tensorflow: Finetune pretrained model on new dataset with different number of classes 【发布时间】:2017-06-04 15:24:17 【问题描述】:如何在新数据集上微调 TensorFlow 中的预训练模型?在 Caffe 中,我可以简单地重命名最后一层并为随机初始化设置一些参数。张量流中是否有类似的可能?
假设我有一个检查点文件 (deeplab_resnet.ckpt
) 和一些设置计算图的代码,我可以在其中修改最后一层,使其具有与新数据集的类相同数量的输出。
然后我尝试像这样开始会话:
sess = tf.Session(config=config)
init = tf.initialize_all_variables()
sess.run(init)
trainable = tf.trainable_variables()
saver = tf.train.Saver(var_list=trainable, max_to_keep=40)
saver.restore(sess, 'ckpt_path/deeplab_resnet.ckpt')
但是,这在调用 saver.restore
函数时给了我一个错误,因为它需要与保存它的图形结构完全相同的图形结构。
我怎样才能只加载除'ckpt_path/deeplab_resnet.ckpt'
文件中的最后一层之外的所有权重?
我还尝试更改 Classification
图层名称,但也没有运气......
我正在使用tensorflow-deeplab-resnet model
【问题讨论】:
我不太明白你想要什么。你想修改一个层,即使用它但不同(改变形状等)还是你想使用所有模型但层(使用一个全新的层)。点滴 两者(通常在微调中完成)。我用一个旧模型替换最后一层,用一个适合新数据集的新类别数的模型。然后这最后一层需要随机初始化。我猜@Alexey Romanovs 的答案已经是解决方案的一半。唯一缺少的部分是当网络从以前的 caffemodel 导入时层的显式随机初始化,就像tensorflow-deeplab-resnet
中的情况一样
@mcExchange,请用完整的解决方案填写下面的答案和/或批准下面的解决方案。
【参考方案1】:
您可以指定要恢复的变量的名称。
所以,你可以得到一个模型中所有变量的列表,并过滤掉最后一层的变量:
all_vars = tf.all_variables()
var_to_restore = [v for v in all_vars if not v.name.startswith('xxx')]
saver = tf.train.Saver(var_to_restore)
详情请参阅documentation。
或者,您可以尝试加载整个模型并在最后一层之前创建一个新的“分支”,并在训练期间在成本函数中使用它。
【讨论】:
v.name.startswith('xxx')
是一个很好的提示。尽管网络不再崩溃,但训练它还没有收敛/损失没有减少。我是否必须明确告诉网络随机初始化被遗漏的层? (顺便说一句,我不得不使用all_vars = tf.trainable_variables()
而不是all_vars = tf.all_variables()
您是否能够在不初始化新变量的情况下训练模型?如果您尝试使用已初始化的变量,TensorFlow 应该会给您一个错误。你可以尝试使用tf.variables_initializer(var_list)
,但奇怪的是TensorFlow允许你在不初始化所有变量的情况下训练模型。
这里的问题可能是TF图不是手动创建的,而是从caffemodel转换而来的。所以我没有像myVar = tf.Variable(tf.random_normal([...], stddev=...),name="...")
这样的显式变量定义。我可能必须通过它们的名称选择这些变量,并明确告诉 TF 随机初始化它们。你知道怎么做吗?
您确实有新图层的变量,对吗?对于他们,您可以tf.variables_initializer
来初始化您的新变量。 added_vars = [v for v in ...], init_op = tf.variables_initializer(added_vars), sees.run(init_op)
好吧,在恢复各个变量之前,我已经在执行 init = tf.initialize_all_variables()
和 sess.run(init)
了。所以我想那时所有的变量都应该已经初始化了。我也有一种感觉,由于部分恢复,最后一层可能不会连接到网络的其余部分。至少在张量板中,上一层和最后一层之间没有线条......我会再次检查以上是关于Tensorflow:在具有不同类别数量的新数据集上微调预训练模型的主要内容,如果未能解决你的问题,请参考以下文章
如何使用tensorflow为每个类获取具有相同数量图像的验证集?