Tensorflow:在具有不同类别数量的新数据集上微调预训练模型

Posted

技术标签:

【中文标题】Tensorflow:在具有不同类别数量的新数据集上微调预训练模型【英文标题】:Tensorflow: Finetune pretrained model on new dataset with different number of classes 【发布时间】:2017-06-04 15:24:17 【问题描述】:

如何在新数据集上微调 TensorFlow 中的预训练模型?在 Caffe 中,我可以简单地重命名最后一层并为随机初始化设置一些参数。张量流中是否有类似的可能?

假设我有一个检查点文件 (deeplab_resnet.ckpt) 和一些设置计算图的代码,我可以在其中修改最后一层,使其具有与新数据集的类相同数量的输出。

然后我尝试像这样开始会话:

sess = tf.Session(config=config)
init = tf.initialize_all_variables()

sess.run(init)

trainable = tf.trainable_variables()
saver = tf.train.Saver(var_list=trainable, max_to_keep=40)
saver.restore(sess, 'ckpt_path/deeplab_resnet.ckpt')

但是,这在调用 saver.restore 函数时给了我一个错误,因为它需要与保存它的图形结构完全相同的图形结构。 我怎样才能只加载除'ckpt_path/deeplab_resnet.ckpt' 文件中的最后一层之外的所有权重? 我还尝试更改 Classification 图层名称,但也没有运气......

我正在使用tensorflow-deeplab-resnet model

【问题讨论】:

我不太明白你想要什么。你想修改一个层,即使用它但不同(改变形状等)还是你想使用所有模型但层(使用一个全新的层)。点滴 两者(通常在微调中完成)。我用一个旧模型替换最后一层,用一个适合新数据集的新类别数的模型。然后这最后一层需要随机初始化。我猜@Alexey Romanovs 的答案已经是解决方案的一半。唯一缺少的部分是当网络从以前的 caffemodel 导入时层的显式随机初始化,就像 tensorflow-deeplab-resnet 中的情况一样 @mcExchange,请用完整的解决方案填写下面的答案和/或批准下面的解决方案。 【参考方案1】:

您可以指定要恢复的变量的名称。

所以,你可以得到一个模型中所有变量的列表,并过滤掉最后一层的变量:

all_vars = tf.all_variables()
var_to_restore = [v for v in all_vars if not v.name.startswith('xxx')]

saver = tf.train.Saver(var_to_restore)

详情请参阅documentation。

或者,您可以尝试加载整个模型并在最后一层之前创建一个新的“分支”,并在训练期间在成本函数中使用它。

【讨论】:

v.name.startswith('xxx') 是一个很好的提示。尽管网络不再崩溃,但训练它还没有收敛/损失没有减少。我是否必须明确告诉网络随机初始化被遗漏的层? (顺便说一句,我不得不使用all_vars = tf.trainable_variables() 而不是all_vars = tf.all_variables() 您是否能够在不初始化新变量的情况下训练模型?如果您尝试使用已初始化的变量,TensorFlow 应该会给您一个错误。你可以尝试使用tf.variables_initializer(var_list),但奇怪的是TensorFlow允许你在不初始化所有变量的情况下训练模型。 这里的问题可能是TF图不是手动创建的,而是从caffemodel转换而来的。所以我没有像myVar = tf.Variable(tf.random_normal([...], stddev=...),name="...") 这样的显式变量定义。我可能必须通过它们的名称选择这些变量,并明确告诉 TF 随机初始化它们。你知道怎么做吗? 您确实有新图层的变量,对吗?对于他们,您可以tf.variables_initializer 来初始化您的新变量。 added_vars = [v for v in ...], init_op = tf.variables_initializer(added_vars), sees.run(init_op) 好吧,在恢复各个变量之前,我已经在执行 init = tf.initialize_all_variables()sess.run(init) 了。所以我想那时所有的变量都应该已经初始化了。我也有一种感觉,由于部分恢复,最后一层可能不会连接到网络的其余部分。至少在张量板中,上一层和最后一层之间没有线条......我会再次检查

以上是关于Tensorflow:在具有不同类别数量的新数据集上微调预训练模型的主要内容,如果未能解决你的问题,请参考以下文章

不同类别分类的数据集数量是不是重要

如何使用tensorflow为每个类获取具有相同数量图像的验证集?

如何对具有三个不同类别的 3 个圆形数据集执行光谱聚类

使用 tensorflow keras 预测 5 个不同类别的标签

分类不平衡对软件缺陷预测模型性能的影响研究(笔记)

测试和训练数据集具有不同数量的特征