A Study of Fine-Tuning Principles in Deep Learning Networks
Posted by Han Zheng
I. What Is a Pre-Trained Model?
A pre-trained model is a model whose parameters have already been trained on a dataset, typically a large-scale one. Common examples are networks such as:
- VGG16/19
- ResNet
trained on large datasets such as:
- ImageNet
- COCO
Networks like VGG16/19 that are commonly used in image recognition are excellent architectures already tuned by others; normally we do not need to modify their structure.
References:
https://zhuanlan.zhihu.com/p/35890660
https://github.com/szagoruyko/loadcaffe
II. What Is Model Fine-Tuning?
The basic principle of model fine-tuning can be explained with a single-neuron network:
- Step 1: Suppose our neural network has the form: Y = W * X
- Step 2: We want to find a W such that the input X = 2 produces the output Y = 1, i.e. we want W = 0.5: 1 = W * 2
- Step 3: Following the standard training procedure, W is first initialized, e.g. from a distribution with mean 0 and variance 1; suppose W is initialized to 0.1: Y = 0.1 * X
- Step 4: Now run the forward pass (FP): with X = 2 and W = 0.1, the output is Y = 0.2, so the error between the actual output and the target 1 is 0.8: 1 <====== 0.2 = 0.1 * 2
- Step 5: Backpropagation (BP) then uses the 0.8 error to update the weight W; suppose this update gives W = 0.2, so the output is 0.4 and the error shrinks to 0.6: 1 <====== 0.4 = 0.2 * 2
- Step 6: After perhaps 10 or 20 rounds of backpropagation, W finally reaches the desired 0.5: Y = 0.5 * X
- Step 7: Now suppose that before training started, someone told you W should be around 0.47.
- Step 8: Training from there, the initial error is only 0.06, so one or two BP steps may be enough to adjust W to 0.5: 1 <====== 0.94 = 0.47 * 2
Step 7 amounts to being handed a pre-trained model, and Step 8 is fine-tuning on top of that pre-trained model. A runnable sketch of this toy example follows below.
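To make the toy example concrete, here is a minimal sketch in plain Python. The squared-error loss and vanilla gradient descent (with learning rate 0.1) are my assumptions; the steps above do not fix a specific update rule.

```python
# Minimal sketch of the single-neuron example above.
# Assumed: loss L = (target - w*x)^2 and plain gradient descent;
# the article's steps do not prescribe a specific update rule.

def train(w, x=2.0, target=1.0, lr=0.1, tol=1e-3):
    """Gradient descent on L = (target - w*x)^2; returns (w, steps)."""
    steps = 0
    while abs(target - w * x) > tol:
        grad = -2 * (target - w * x) * x  # dL/dw
        w -= lr * grad
        steps += 1
    return w, steps

w_scratch, n_scratch = train(w=0.1)   # "training from scratch"
w_tuned, n_tuned = train(w=0.47)      # "fine-tuning" from a good init
print(f"from W=0.10: {n_scratch} steps -> W={w_scratch:.3f}")
print(f"from W=0.47: {n_tuned} steps -> W={w_tuned:.3f}")
```

Under these assumptions, starting from W = 0.1 takes several more updates to converge than starting near 0.47, mirroring Steps 4 through 8.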
Compared with training from scratch, fine-tuning saves a great deal of compute and training time, and can even improve accuracy: in very large-scale training the model may get stuck in a poor local optimum, whereas pre-training has already covered the hardest part of the road, making the downstream model's path much easier.
A careful reader may notice that how well a pre-trained model serves a downstream fine-tuning task depends on several factors:
- The overlap between the pre-training corpus and the downstream fine-tuning task: fundamentally, a pre-trained model's weights encode the data it was fed. If the pre-training task and the downstream domain differ too much, the pre-trained parameters provide little benefit and may even hurt.
- The capacity of the pre-trained model itself: in principle, if the model is large enough to cover part of the core of the downstream task, fine-tuning can re-weight it, activating some neurons and suppressing others, so that the model "grows" toward the downstream task.
- Whether the pre-training corpus is large and diverse enough, since this determines whether pre-training has actually converged. If the upstream model has not converged, it will still need substantial adjustment during fine-tuning, greatly slowing overall convergence. Conversely, if the pre-trained model has essentially converged, the downstream data requirement is small: fine-tuning can achieve good results on a small dataset with little training time.
- Whether the pre-trained model's input-layer design (vectorization scheme, tensor shapes, embedding method, encoding scheme, and so on) matches the downstream task's input structure, or at least transfers to it. The input layer is a codified form of feature-engineering experience and itself embodies the model's abstraction of the target task. For example, forcibly feeding a pixel image into a model built for text generation will not train or predict well.
III. Why Fine-Tune?
The core of a convolutional neural network is:
- shallow convolutional layers extract low-level features such as edges and contours;
- deep convolutional layers extract abstract features such as a whole face shape;
- fully connected layers score and classify based on combinations of features.
A model pre-trained on a large dataset already has the ability to extract both low-level and abstract features. Compared with not fine-tuning, this approach has the following advantages:
- It avoids training from scratch, reducing training time and saving compute.
- It mitigates problems such as non-convergence, under-optimized parameters, low accuracy, poor generalization, and overfitting.
IV. How to Fine-Tune for Different Datasets
Dataset 1: small data, very high similarity to the pre-training data
In this case we only modify the last few layers or the output classes of the final softmax layer.
Dataset 2: small data, low similarity
Here we can freeze the initial layers of the pre-trained model (say the first k) and retrain the remaining (n - k) layers. Because the new dataset has low similarity, retraining the higher layers on it is what matters; see the sketch below.
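As a hedged sketch of this strategy (tensorflow.keras and VGG16 are illustrative choices of mine, and k = 7 is a placeholder you would tune, not a value from the text):

```python
# Sketch: freeze the first k layers of a pre-trained backbone and
# leave the remaining (n - k) layers trainable.
# The backbone (VGG16) and k are illustrative placeholders.
from tensorflow.keras.applications.vgg16 import VGG16

base = VGG16(weights='imagenet', include_top=False)
k = 7
for i, layer in enumerate(base.layers):
    layer.trainable = (i >= k)  # False for the first k layers, True after
```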
Dataset 3: large data, low similarity
In this case, the large dataset makes neural-network training effective; but because our data differs greatly from the data used to pre-train the model, predictions from the pre-trained model would not be effective. It is therefore best to train the network from scratch on your own data.
Dataset 4: large data, high similarity
This is the ideal case, where a pre-trained model should be most effective. The best approach is to keep the model's architecture and its initial weights, then retrain the whole model starting from the pre-trained weights.
V. Fine-Tuning Guidelines
- The usual practice is to truncate the last layer (the softmax layer) of the pre-trained network and replace it with a new softmax layer matching our own problem. For example, a network pre-trained on ImageNet comes with a 1000-class softmax layer; if our task is 10-class classification, the new softmax layer has 10 outputs instead of 1000. We then continue training from the pre-trained weights. Make sure to use cross-validation so the network generalizes well.
- Train the network with a smaller learning rate. Since we expect the pre-trained weights to be quite good already compared with random initialization, we do not want to distort them too quickly or too much. A common practice is to make the initial learning rate 10 times smaller than the one used for training from scratch.
- If the dataset is very small, we typically train only the last layer; if it is moderately sized, freezing the weights of the pre-trained network's first few layers is also common practice. The early layers capture universal features relevant to our new problem, such as curves and edges, and we want to keep those weights unchanged, letting the network instead focus on learning dataset-specific features in the later layers. A hedged sketch of these guidelines follows below.
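A minimal sketch putting the three guidelines together, again assuming tensorflow.keras and a VGG16 backbone; the 10-class head matches the example above, and base_lr is a stand-in for whatever rate you would use when training from scratch:

```python
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16

base = VGG16(weights='imagenet', include_top=False, pooling='avg')

# Guideline 1: replace the 1000-way ImageNet softmax with a 10-way one.
outputs = tf.keras.layers.Dense(10, activation='softmax')(base.output)
model = tf.keras.Model(inputs=base.input, outputs=outputs)

# Guideline 3: on smaller datasets, freeze the early, generic layers.
for layer in base.layers[:4]:
    layer.trainable = False

# Guideline 2: fine-tune at ~1/10 of the from-scratch learning rate.
base_lr = 1e-3  # assumed from-scratch rate, purely illustrative
model.compile(optimizer=tf.keras.optimizers.Adam(base_lr / 10),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```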
VI. Exploring the Essence of Fine-Tuning via Kernel Visualization
Common pre-trained classification networks include Oxford's VGG models, Google's Inception models, and Microsoft's ResNet models, all convolutional neural networks (CNNs) pre-trained for classification and detection.
Here we use VGG16, a model pre-trained on ImageNet with excellent classification performance and good adaptability to other datasets.
0x1: Predicting a Handwritten Digit Directly with VGG16
```python
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np

model = VGG16(weights='imagenet')

img_path = '6.webp'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])
```
Output:
```
Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]
```
VGG16's highest-probability prediction is "hook"; clearly, VGG16's training set contained no samples of digit images.
0x2: Visualizing VGG16's Layer Activations on the Handwritten Digit
```python
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np
import cv2
import matplotlib.pyplot as plt


def vis_conv(images, n, name, t):
    """Visualize conv output or conv filters on an n x n grid.

    Args:
        images: feature maps or filter images.
        n: number of grid columns and rows.
        name: save name.
        t: vis type, 'conv' or 'filter'.
    """
    size = 64
    margin = 5

    if t == 'filter':
        results = np.zeros((n * size + 7 * margin, n * size + 7 * margin, 3))
    if t == 'conv':
        results = np.zeros((n * size + 7 * margin, n * size + 7 * margin))

    for i in range(n):
        for j in range(n):
            if t == 'filter':
                filter_img = images[i + (j * n)]
            if t == 'conv':
                filter_img = images[..., i + (j * n)]
            filter_img = cv2.resize(filter_img, (size, size))

            # Put the result in the square `(i, j)` of the results grid
            horizontal_start = i * size + i * margin
            horizontal_end = horizontal_start + size
            vertical_start = j * size + j * margin
            vertical_end = vertical_start + size
            if t == 'filter':
                results[horizontal_start:horizontal_end, vertical_start:vertical_end, :] = filter_img
            if t == 'conv':
                results[horizontal_start:horizontal_end, vertical_start:vertical_end] = filter_img

    # Display and save the results grid
    plt.imshow(results)
    plt.savefig('images/{}_{}.jpg'.format(t, name), dpi=600)
    plt.show()


def conv_output(model, layer_name, img):
    """Get the output of a conv layer.

    Args:
        model: keras model.
        layer_name: name of layer in the model.
        img: processed input image.

    Returns:
        intermediate_output: feature map.
    """
    # this is the placeholder for the input images
    input_img = model.input
    try:
        # this is the placeholder for the conv output
        out_conv = model.get_layer(layer_name).output
    except Exception:
        raise Exception('No layer named {}!'.format(layer_name))
    # build the intermediate-layer model and query its output
    intermediate_layer_model = Model(inputs=input_img, outputs=out_conv)
    intermediate_output = intermediate_layer_model.predict(img)
    return intermediate_output[0]


if __name__ == '__main__':
    model = VGG16(weights='imagenet')

    img_path = '6.webp'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    print('Predicted:', decode_predictions(preds, top=3)[0])

    # visualize activations from shallow to deep layers
    for layer_name in ['block1_conv1', 'block1_conv2', 'block2_conv1',
                       'block2_conv2', 'block3_conv1', 'block3_conv2',
                       'block5_conv3']:
        out = conv_output(model, layer_name, x)
        print(layer_name + ': ', out)
        vis_conv(out, 8, layer_name, 'conv')

    print('fc1: ', conv_output(model, 'fc1', x))
    print('fc2: ', conv_output(model, 'fc2', x))
    print('predictions: ', conv_output(model, 'predictions', x))
```
Output (heavily truncated): the top-3 prediction is again `hook`, followed by the raw activation tensors for each layer from `block1_conv1` through `block5_conv3`, interleaved with TensorFlow warnings about `tf.function` retracing (caused by building a new intermediate model for every layer). Notably, the shallow layers (`block1_conv1`, `block1_conv2`, ...) produce dense, strongly activated feature maps, while the deep `block5_conv3` activations are almost entirely zero for this digit image: the high-level, ImageNet-specific features simply do not fire on an out-of-domain input, and these are exactly the layers fine-tuning must re-learn.
[Deep Learning Frontier Applications] Image Classification Fine-Tuning
Preface
1. What is the pretrain-finetune paradigm?
In computer vision, the pretrain-finetune paradigm has been in use for many years: model parameters are pre-trained on a large-scale image dataset, and the trained parameters are then fine-tuned on a new small-dataset task, producing a model that generalizes better.
2. What is ResNet?
ResNet is one of the most commonly used pre-trained models; its core operations are convolution and residual connections. The convolutional layers use 3×3 filters and follow two simple design rules: (1) layers producing feature maps of the same size have the same number of filters; (2) when the feature-map size is halved, the number of filters is doubled to keep the per-layer time complexity constant. Downsampling is done directly by convolutions with stride 2, and the network ends with a global average pooling layer and a 1000-way fully connected layer with softmax. The 34-layer variant is called ResNet34 (see Figure 1). A sketch of the basic residual block appears below.
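For illustration, here is a minimal sketch of the basic residual block described above, written with paddle.nn to match the code in this section. It shows only the identity-shortcut case (same channel count, stride 1), a simplification of the full ResNet34 block:

```python
# A hedged sketch of a basic ResNet residual block:
# 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus an identity shortcut.
# Identity case only (same channels, stride 1); the real network also
# has downsampling blocks with stride-2 convs and projection shortcuts.
import paddle.nn as nn
import paddle.nn.functional as F

class BasicBlock(nn.Layer):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2D(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2D(channels)
        self.conv2 = nn.Conv2D(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2D(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)  # residual (shortcut) connection
```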
In this section we use a ResNet34 pretrain-finetune pipeline to implement 12-class cat-face classification: given a cat face, predict which category it belongs to.
I. Data Loading and Preprocessing
The dataset comes from an open dataset on the web (https://aistudio.baidu.com/aistudio/datasetdetail/10954). It contains 12 classes of cat images, 2160 images in total; some samples are shown in Figure 1.
(1) Data Loading and Preprocessing
First mount the dataset into the current project, then read the data file and split it 8:2 into training and validation sets.
- Import the required packages:
```python
import os
import time
import os.path as osp
import zipfile
import numpy as np
import paddle
import paddle.nn as nn
import pandas as pd
import paddle.nn.functional as F
from PIL import Image
from paddle.io import Dataset, DataLoader
from paddle.optimizer import Adam
from paddle.vision import Compose, ToTensor, Resize
from paddle.vision.models import resnet34
from paddle.metric import Accuracy
from sklearn.model_selection import StratifiedShuffleSplit
```
- Split the labeled data into training and validation sets:
```python
info = pd.read_csv(osp.join('./data', 'train_list.txt'), sep='\t', header=None)
images, labels = info.iloc[:, 0], info.iloc[:, 1]

# stratified 80/20 split to preserve class proportions
split = StratifiedShuffleSplit(test_size=0.2)
train_idx, valid_idx = next(split.split(images, labels))
info_tr = info.iloc[train_idx, :]
info_va = info.iloc[valid_idx, :]

info_tr.to_csv('data/train.csv', header=False, index=False)
info_va.to_csv('data/valid.csv', header=False, index=False)
```
(2) Dataset Wrapper
```python
class CatDataset(Dataset):
    train_file = 'cat_12_train.zip'
    test_file = 'cat_12_test.zip'
    train_label = 'train_list.txt'

    def __init__(self, root, mode, transform=None):
        super(CatDataset, self).__init__()
        self.root = root
        self.mode = mode
        self.transform = transform
        if not osp.isfile(osp.join(root, self.train_file)) or \
           not osp.isfile(osp.join(root, self.train_label)) or \
           not osp.isfile(osp.join(root, self.test_file)):
            raise ValueError('wrong data path')
        # extract the archives on first use
        if not osp.isdir(osp.join(self.root, 'cat_12_train')):
            with zipfile.ZipFile(osp.join(root, self.train_file)) as f:
                f.extractall(root)
            with zipfile.ZipFile(osp.join(root, self.test_file)) as f:
                f.extractall(root)

        if mode == 'train':        # the full labeled set
            info = pd.read_csv(osp.join(root, 'train_list.txt'), sep='\t', header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(info.iloc[:, 1].to_list())
        elif mode == 'train_':     # the 80% training split
            info = pd.read_csv(osp.join(root, 'train.csv'), header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(info.iloc[:, 1].to_list())
        elif mode == 'valid_':     # the 20% validation split
            info = pd.read_csv(osp.join(root, 'valid.csv'), header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(info.iloc[:, 1].to_list())
        else:                      # unlabeled test images
            images = os.listdir(os.path.join(root, 'cat_12_test'))
            self.images = ['cat_12_test/' + image for image in images]
            self.labels = None

    def __getitem__(self, idx):
        image = Image.open(osp.join(self.root, self.images[idx]))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        if self.mode == 'test':
            return image,
        else:
            label = self.labels[idx]
            return image, label

    def __len__(self):
        return len(self.images)
```
(3) Sample Loading and Statistics
```python
paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')

transform = Compose([
    Resize([224, 224]),
    ToTensor()
])

train_ds = CatDataset('./data', 'train_', transform)
valid_ds = CatDataset('./data', 'valid_', transform)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=False)

print('Training samples:', train_ds.__len__())
print('Validation samples:', valid_ds.__len__())
```
II. Loading the Pre-Trained Model
paddle.vision is PaddlePaddle's high-level vision API; it bundles common datasets and common pre-trained models such as LeNet, the VGG series, the ResNet series, and the MobileNet series. This experiment uses resnet34 to demonstrate fine-tuning for image classification.
With the dataset ready, load the pre-trained model by calling net = resnet34(pretrained=True). Setting pretrained=True uses the pre-trained parameters (downloaded automatically on first use); otherwise the parameters must be trained from scratch.
Load the pre-trained model and set the number of classes to 12 (the cat categories):
```python
net = resnet34(pretrained=True, num_classes=12)
```
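Optionally, if you want to begin by training only the new 12-way head (in the spirit of the small-dataset guidelines in the first part of this article), one way in Paddle is to stop gradients on the backbone parameters. This is my own hedged variation, not part of the original recipe, which fine-tunes all layers with a small learning rate:

```python
# Optional sketch: freeze everything except the final fc layer so only
# the new 12-way head trains. Assumes paddle.vision's resnet34 names
# its final fully connected layer 'fc'.
for name, param in net.named_parameters():
    if not name.startswith('fc'):
        param.stop_gradient = True  # exclude this parameter from updates
```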
III. Model Fine-Tuning
With the pre-trained model loaded, define the optimizer and evaluation metric, feed in the domain data, and run fine-tuning:
(1) Define the optimizer
```python
optimizer = Adam(
    parameters=net.parameters(),
    learning_rate=1e-5
)
```
(2) Define the loss function
```python
loss_fn = nn.CrossEntropyLoss()
```
(3) Define the accuracy metric
```python
metric_fn = Accuracy()
```
(4) Fine-tune for 20 epochs
```python
for epoch in range(20):
    t0 = time.time()

    # training pass
    net.train()
    for data, label in train_dl:
        logit = net(data)
        loss = loss_fn(logit, label.astype('int64'))
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

    # evaluation pass
    net.eval()
    loss_tr = 0.
    for data, label in train_dl:
        logit = net(data)
        label = label.astype('int64')
        loss_tr += loss_fn(logit, label).cpu().numpy()[0]
    loss_tr /= len(train_dl)

    loss_va = 0.
    for data, label in valid_dl:
        label = label.astype('int64')
        logit = net(data)
        loss_va += loss_fn(logit, label).cpu().numpy()[0]
        metric_fn.update(metric_fn.compute(logit, label))
    loss_va /= len(valid_dl)
    acc_va = metric_fn.accumulate()
    metric_fn.reset()

    t = time.time() - t0
    print('[Epoch {:3d} {:.2f}s] train loss({:.4f}); valid loss({:.4f}), acc({:.2f})'
          .format(epoch, t, loss_tr, loss_va, acc_va))
```
Part of the training output is shown in Figure 2.
IV. Model Prediction
```python
import matplotlib.image as mpimg
import matplotlib.pyplot as plt


def show_image(file_name):
    img = mpimg.imread('data/' + file_name)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.show()


test_ds = CatDataset('./data', mode='test', transform=transform)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)

test_pred = []
with paddle.no_grad():
    for data, in test_dl:
        logit = net(data)
        pred = paddle.argmax(F.softmax(logit, axis=-1), axis=-1)
        test_pred.append(pred.cpu().numpy())
test_pred = np.concatenate(test_pred, axis=0)

for image, pred in zip(test_ds.images, test_pred.astype(int)):
    img = mpimg.imread('data/' + image)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.show()
    print('Image path: %s, predicted class: %d\n' % (image.split('/')[1], pred))
```
Part of the prediction output is shown in Figures 3 through 6.
Summary
This series consists of notes and reflections based on Machine Learning in Practice (《机器学习实践》, Tsinghua University Press); all code is implemented on Baidu PaddlePaddle. If anything here is improper or infringing, please message me privately and I will promptly address it!
Finally, to close the article, a line quoted from this event ~( ̄▽ ̄~)~:
**The biggest reason to learn is to escape mediocrity: start a day earlier and gain one more day of a richer life; start a day later and endure one more day of the ordinary.**