A Study of How Fine-Tuning Works in Deep Learning Networks — Using Convolutional Neural Networks as an Example

Posted by Han Zheng

I. What is a pre-trained model?

A pre-trained model is a model whose parameters have already been trained on a dataset, usually a large one. For example:

  • model architectures such as VGG16/19 and ResNet
  • trained on large datasets such as ImageNet and COCO

In practice, the VGG16/19-style networks commonly used for image recognition are excellent architectures that others have already designed and tuned; we normally have no need to modify their structure.

References:

https://zhuanlan.zhihu.com/p/35890660
https://github.com/szagoruyko/loadcaffe

 

II. What is model fine-tuning?

Let us explain the basic idea of fine-tuning with a single-neuron network:

  • Step 1: Suppose our neural network has the form: Y = W * X
  • Step 2: We want to find a W such that input X = 2 produces output Y = 1, i.e. we want W = 0.5: 1 = W * 2
  • Step 3: Following the standard training procedure, W is first initialized, say from a distribution with mean 0 and variance 1; suppose it is initialized to 0.1: Y = 0.1 * X
  • Step 4: The forward pass (FP) begins: with X = 2 and W = 0.1 the output is Y = 0.2, so the error against the target value 1 is 0.8: 1 <====== 0.2 = 0.1 * 2
  • Step 5: Backpropagation (BP) then uses the 0.8 error to update the weight W; suppose this update gives W = 0.2, so the output becomes 0.4 and the error drops to 0.6: 1 <====== 0.4 = 0.2 * 2
  • Step 6: After perhaps 10 or 20 rounds of BP, W finally reaches the 0.5 we wanted: Y = 0.5 * X
  • Step 7: Now suppose that, before training started, someone told you W should be around 0.47
  • Step 8: Starting from there, the initial error is only 0.06, so one or two BP steps may be all it takes to bring W to 0.5: 1 <====== 0.94 = 0.47 * 2

Step 7 is what a pre-trained model gives you; Step 8 is fine-tuning on top of that pre-trained model. The sketch below makes the two runs concrete.
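A minimal plain-Python sketch of the two runs above, using gradient descent on the squared-error loss (W*X - Y)^2; the 0.05 learning rate is our own illustrative choice, and exact step counts depend on it:

X, Y_target, lr = 2.0, 1.0, 0.05

def train(W, max_steps=50, tol=1e-3):
    for step in range(max_steps):
        error = W * X - Y_target        # forward pass (FP)
        if abs(error) < tol:
            return W, step              # converged
        W -= lr * 2 * error * X         # backprop (BP): dLoss/dW = 2 * error * X
    return W, max_steps

print(train(0.10))  # random init: converges after noticeably more steps
print(train(0.47))  # near-optimal "pre-trained" init: only a handful of steps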

Compared with training from scratch, fine-tuning saves a great deal of compute and training time, and can even improve accuracy: during very large-scale training a model may get stuck in a poor local region of the loss landscape, and pre-training has already scouted the hardest part of the path, so the downstream model has a much easier walk.

A careful reader may notice that how much a pre-trained model helps a downstream fine-tuning task depends on several factors:

  • The overlap between the corpus the model was pre-trained on and the downstream fine-tuning task. Fundamentally, the weight parameters of a pre-trained model are a representation of the data it was fed. If the pre-training task and the downstream fine-tuning task come from very different domains, the pre-trained parameters will contribute almost nothing and may even hurt.
  • The capacity of the pre-trained model itself. In principle, if the pre-trained model is large enough to already contain part of the core structure of the downstream task, then during fine-tuning its weights can be re-adjusted, activating some neurons and shutting off others, so that the model "grows" toward the downstream task.
  • Whether the pre-training corpus is large and diverse enough, because this determines whether pre-training has actually run to completion. If the upstream model has not converged, it will still need heavy adjustment when the downstream fine-tuning task is attached, which greatly slows overall convergence. Conversely, if the pre-trained model has essentially converged, the demands on the downstream dataset are small: fine-tuning can reach good results from a small dataset, and in little training time.
  • Whether the input-layer design of the pre-trained model (its vectorization scheme, tensor shapes, embedding and encoding methods) matches that of the downstream task, or at least transfers to it. The structure of the input layer is a codified form of feature-engineering experience, and it encodes the model's own abstraction of the target task. As an analogy, if you force-fit a pixel image into a model built for text generation, neither training nor prediction will go well.

 

III. Why fine-tune?

The core idea of a convolutional neural network is:

  • shallow convolutional layers extract basic features, such as edges and contours
  • deep convolutional layers extract abstract features, such as a whole face shape
  • fully connected layers score and classify based on combinations of those features

A model pre-trained on a large dataset has already acquired the ability to extract both the shallow basic features and the deep abstract features. Compared with not fine-tuning, this approach has the following advantages:

  • it avoids training from scratch, reducing training time and saving compute
  • it mitigates problems such as non-convergence, under-optimized parameters, low accuracy, poor generalization, and a tendency to overfit

 

IV. How to fine-tune on different datasets

Dataset 1: small, but highly similar to the pre-training data

In this case, all we do is modify the last few layers, or the output classes of the final softmax layer, as in the sketch below.
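A minimal sketch of this case in tf.keras, assuming a 10-class downstream task; everything except the new head keeps its pre-trained weights and stays frozen:

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model

# keep the whole pretrained convolutional base frozen; only the new head is trained
base = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base.trainable = False

x = Flatten()(base.output)
outputs = Dense(10, activation='softmax')(x)   # new 10-class softmax head
model = Model(base.input, outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])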

Dataset 2: small, and dissimilar

In this case we can freeze the initial layers of the pre-trained model (say the first k layers) and retrain the remaining (n - k) layers. Because the new dataset has low similarity to the pre-training data, retraining the higher layers on it is what matters most. A sketch follows.
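A sketch of the freezing itself; k = 10 is an arbitrary illustrative cut point to tune per task:

from tensorflow.keras.applications.vgg16 import VGG16

base = VGG16(weights='imagenet', include_top=False)

k = 10  # illustrative cut point
for layer in base.layers[:k]:
    layer.trainable = False    # freeze the first k layers
for layer in base.layers[k:]:
    layer.trainable = True     # retrain the remaining (n - k) layers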

Dataset 3: large, but dissimilar

In this case the large dataset means our own training will be effective, but because our data differs substantially from the data the pre-trained model was trained on, predictions carried over from it will not work well. It is therefore better to train the neural network from scratch on your own data (training from scratch).

Dataset 4: large, and similar

This is the ideal situation, where a pre-trained model should be at its most effective. The best way to use it is to keep the model's architecture and its initial (pre-trained) weights, and then retrain the whole model starting from those weights.

 

V. Fine-tuning guidelines

  • The usual practice is to truncate the last layer (the softmax layer) of the pre-trained network and replace it with a new softmax layer relevant to our own problem. For example, a network pre-trained on ImageNet comes with a 1000-class softmax layer; if our task is 10-class classification, the network's new softmax layer has 10 outputs instead of 1000. We then continue training from the pre-trained weights, making sure to cross-validate so that the network generalizes well.
  • Use a smaller learning rate to train the network. Since we expect the pre-trained weights to be far better than randomly initialized ones, we do not want to distort them too quickly or too much. A common practice is to make the initial learning rate 10 times smaller than the rate used for training from scratch.
  • If the dataset is very small, it is common to train only the last layer; if it is of moderate size, freezing the weights of the first few layers of the pre-trained network is also common practice. These early layers capture universal features such as curves and edges that remain relevant to our new problem, so we want to keep their weights intact and instead let the network concentrate on learning dataset-specific features in the subsequent deeper layers (see the sketch after this list).
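For the learning-rate rule of thumb, a minimal sketch reusing the model built in the Dataset 1 sketch above; 1e-4 stands in for one tenth of a typical from-scratch rate of 1e-3:

from tensorflow.keras.optimizers import Adam

# from-scratch rate ~1e-3  ->  fine-tuning rate ~1e-4 (10x smaller)
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='categorical_crossentropy',
              metrics=['accuracy'])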

 

VI. Probing the essence of fine-tuning by visualizing convolution kernels

Common pre-trained classification networks include Oxford's VGG models, Google's Inception models, and Microsoft's ResNet models, all of which are convolutional neural networks (CNNs) pre-trained for classification and detection.

Here we use the VGG16 model, pre-trained on the ImageNet dataset; it has excellent classification performance and adapts well to other datasets.

0x1: Predicting a handwritten digit directly with VGG16

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np

# load VGG16 with its ImageNet-pretrained weights
model = VGG16(weights='imagenet')

img_path = '6.webp'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])

Output:

Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]

As expected, VGG16's highest-probability prediction is "hook": its training set clearly contained no handwritten-digit samples. A sketch of the fine-tuning fix follows.
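Following the recipes from Sections IV and V, a hedged sketch of the fix: freeze VGG16's convolutional base and train a new 10-way digit head on your own digit images. digit_images and digit_labels are placeholders for a dataset you would supply:

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

base = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base.trainable = False                              # keep the generic edge/texture filters

x = Flatten()(base.output)
x = Dense(256, activation='relu')(x)
digit_head = Dense(10, activation='softmax')(x)     # 10 digit classes
digit_model = Model(base.input, digit_head)

digit_model.compile(optimizer=Adam(learning_rate=1e-4),   # small fine-tuning rate
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
# digit_model.fit(digit_images, digit_labels, epochs=5)   # your own digit dataset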

0x2: Visualizing the parameters of each VGG16 layer on the handwritten digit

from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt


def vis_conv(images, n, name, t):
    """Visualize conv outputs or conv filters on an n x n grid.

    Args:
           images: feature maps (t == 'conv') or filter images (t == 'filter').
           n: number of grid rows and columns.
           name: name used when saving the figure.
           t: visualization type, 'conv' or 'filter'.
    """
    size = 64
    margin = 5

    if t == 'filter':
        results = np.zeros((n * size + (n - 1) * margin, n * size + (n - 1) * margin, 3))
    if t == 'conv':
        results = np.zeros((n * size + (n - 1) * margin, n * size + (n - 1) * margin))

    for i in range(n):
        for j in range(n):
            if t == 'filter':
                filter_img = images[i + (j * n)]
            if t == 'conv':
                filter_img = images[..., i + (j * n)]
            filter_img = cv2.resize(filter_img, (size, size))

            # Put the result in the square `(i, j)` of the results grid
            horizontal_start = i * size + i * margin
            horizontal_end = horizontal_start + size
            vertical_start = j * size + j * margin
            vertical_end = vertical_start + size
            if t == 'filter':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end, :] = filter_img
            if t == 'conv':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end] = filter_img

    # Display and save the results grid
    os.makedirs('images', exist_ok=True)
    plt.imshow(results)
    plt.savefig('images/{}_{}.jpg'.format(t, name), dpi=600)
    plt.show()


def conv_output(model, layer_name, img):
    """Get the output of a conv layer.

    Args:
           model: keras model.
           layer_name: name of a layer in the model.
           img: preprocessed input image.

    Returns:
           intermediate_output: feature map.
    """
    # this is the placeholder for the input images
    input_img = model.input

    try:
        # this is the output tensor of the requested layer
        out_conv = model.get_layer(layer_name).output
    except Exception:
        raise Exception('No layer named {}!'.format(layer_name))

    # build a model that maps the input to the intermediate layer's output
    intermediate_layer_model = Model(inputs=input_img, outputs=out_conv)

    # run the image through it to get the feature map
    intermediate_output = intermediate_layer_model.predict(img)

    return intermediate_output[0]


if __name__ == '__main__':
    model = VGG16(weights='imagenet')

    img_path = '6.webp'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print('Predicted:', decode_predictions(preds, top=3)[0])

    # visualize the feature maps of shallow, middle, and deep conv layers
    for layer_name in ['block1_conv1', 'block1_conv2',
                       'block2_conv1', 'block2_conv2',
                       'block3_conv1', 'block3_conv2',
                       'block5_conv3']:
        feature_map = conv_output(model, layer_name, x)
        print(layer_name + ': ', feature_map)
        vis_conv(feature_map, 8, layer_name, 'conv')

    print("fc1: ", conv_output(model, "fc1", x))
    print("fc2: ", conv_output(model, "fc2", x))
    print("predictions: ", conv_output(model, "predictions", x))
Output (abridged; each layer's full feature-map dump runs to dozens of lines, so only the opening rows are kept here):

1/1 [==============================] - 2s 2s/step
Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]
block1_conv1:  [[[  0.         42.11969     0.        ...   0.         32.04823     0.       ]
  [  0.         46.303555   82.50592   ...   0.        324.38284   164.56157  ]
  ...]]]
block1_conv2:  [[[982.48444    65.59724     0.        ...  81.02978   698.99084   172.65338  ]
  ...]]]
block2_conv1:  [[[   0.          0.        146.08685  ... 1138.9917      0.       1914.1439  ]
  ...]]]
block2_conv2:  [[[  19.134865   65.2908    388.85107  ...   77.567345    0.          0.      ]
  ...]]]
block3_conv1:  [[[ 104.03467     0.       7676.7437   ...  284.6595    104.21471   495.0637  ]
  ...]]]
block3_conv2:  [[[   0.        971.66376   794.8841   ...  172.1506     10.597431  794.8708  ]
  ...]]]
block5_conv3:  [[[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  ...
  [0.        0.        0.        ... 0.        0.5987776 0.       ]
  ...]]]

Two things stand out in these activations. The shallow and middle layers (block1 through block3) respond strongly to the digit image: their generic edge and texture detectors fire even on an input far outside the training distribution. The deepest layer shown (block5_conv3), by contrast, is almost entirely zero, because the abstract, ImageNet-specific features it encodes simply do not occur in a handwritten digit. This is precisely the gap that fine-tuning closes: keep the transferable shallow layers, and retrain the deep, task-specific ones.

[Deep Learning Frontier Applications] Image Classification Fine-Tuning


Preface

1. What is the pretrain-finetune paradigm?

In computer vision, the pretrain-finetune pattern has been in use for many years: model parameters are first pre-trained on a large image dataset, and the trained parameters are then fine-tuned on a new, smaller task dataset, producing a model that generalizes better.


2. What is ResNet?

ResNet is one of the most commonly used pre-trained models; its core operations are convolution and the residual connection. The convolutional layers use 3×3 filters and follow two simple design rules: (1) for the same output feature-map size, every layer has the same number of filters; (2) when the feature-map size is halved, the number of filters is doubled, to keep the per-layer time complexity constant. Downsampling is performed directly by convolutions with stride 2, and the network ends with a global average pooling layer and a 1000-way fully connected layer with softmax. The 34-layer variant is called ResNet34 (shown in Figure 1 below); a sketch of its basic building block follows.
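For illustration, a minimal sketch of the basic block just described, written with Paddle's API; the class name BasicBlock is ours, not the framework's exact implementation:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class BasicBlock(nn.Layer):
    """A sketch of the ResNet basic block: two 3x3 convs plus a shortcut."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2D(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias_attr=False)
        self.bn1 = nn.BatchNorm2D(out_channels)
        self.conv2 = nn.Conv2D(out_channels, out_channels, 3,
                               padding=1, bias_attr=False)
        self.bn2 = nn.BatchNorm2D(out_channels)
        # If the block halves the feature map (stride 2) or changes the channel
        # count, project the shortcut with a 1x1 conv so the shapes match.
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2D(in_channels, out_channels, 1, stride=stride, bias_attr=False),
                nn.BatchNorm2D(out_channels),
            )

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + identity)   # the residual connection

# e.g. moving from a 64-channel stage to a 128-channel stage halves the map:
block = BasicBlock(64, 128, stride=2)
y = block(paddle.randn([1, 64, 56, 56]))   # -> shape [1, 128, 28, 28]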

In this section we use the ResNet34 pretrain-finetune framework to implement 12-class cat-face classification: given a cat face, determine which class it belongs to.


I. Data Loading and Preprocessing

The dataset for this experiment comes from an open dataset (https://aistudio.baidu.com/aistudio/datasetdetail/10954) containing 12 classes of cat images, 2,160 images in total; some examples are shown in Figure 1 below.


(1) Data loading and preprocessing

First mount the dataset into the current project, then read the data files and split them 8:2 into a training set and a validation set.

  1. Import the required packages
import os
import time
import os.path as osp
import zipfile

import numpy as np
import paddle
import paddle.nn as nn
import pandas as pd
import paddle.nn.functional as F
from PIL import Image
from paddle.io import Dataset, DataLoader
from paddle.optimizer import Adam
from paddle.vision import Compose, ToTensor, Resize
from paddle.vision.models import resnet34
from paddle.metric import Accuracy
from sklearn.model_selection import StratifiedShuffleSplit

  2. Split the training list into training and validation sets
info = pd.read_csv(osp.join('./data', 'train_list.txt'), sep='\t', header=None)
images, labels = info.iloc[:, 0], info.iloc[:, 1]
split = StratifiedShuffleSplit(test_size=0.2)
train_idx, valid_idx = next(split.split(images, labels))
info_tr = info.iloc[train_idx, :]
info_va = info.iloc[valid_idx, :]
info_tr.to_csv('data/train.csv', header=False, index=False)
info_va.to_csv('data/valid.csv', header=False, index=False)

(2) Dataset wrapper

class CatDataset(Dataset):
    train_file = 'cat_12_train.zip'
    test_file = 'cat_12_test.zip'
    train_label = 'train_list.txt'

    def __init__(self, root, mode, transform=None):
        super(CatDataset, self).__init__()
        self.root = root
        self.mode = mode
        self.transform = transform

        if not osp.isfile(osp.join(root, self.train_file)) or \
                not osp.isfile(osp.join(root, self.train_label)) or \
                not osp.isfile(osp.join(root, self.test_file)):
            raise ValueError('wrong data path')
        # unzip the images on first use
        if not osp.isdir(osp.join(self.root, 'cat_12_train')):
            with zipfile.ZipFile(osp.join(root, self.train_file)) as f:
                f.extractall(root)
            with zipfile.ZipFile(osp.join(root, self.test_file)) as f:
                f.extractall(root)
        if mode == 'train':
            info = pd.read_csv(osp.join(root, 'train_list.txt'), sep='\t', header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(
                info.iloc[:, 1].to_list()
            )
        elif mode == 'train_':
            info = pd.read_csv(osp.join(root, 'train.csv'), header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(
                info.iloc[:, 1].to_list()
            )
        elif mode == 'valid_':
            info = pd.read_csv(osp.join(root, 'valid.csv'), header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(
                info.iloc[:, 1].to_list()
            )
        else:
            images = os.listdir(os.path.join(root, 'cat_12_test'))
            self.images = ['cat_12_test/' + image for image in images]
            self.labels = None

    def __getitem__(self, idx):
        image = Image.open(osp.join(self.root, self.images[idx]))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        if self.mode == 'test':
            return image,
        else:
            label = self.labels[idx]
            return image, label

    def __len__(self):
        return len(self.images)

(3) Sample statistics

paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')
transform = Compose([
    Resize([224, 224]),
    ToTensor()
])


train_ds = CatDataset('./data', 'train_', transform)
valid_ds = CatDataset('./data', 'valid_', transform)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=False)
print('Training samples:', len(train_ds))
print('Validation samples:', len(valid_ds))


II. Loading the Pre-trained Model

paddle.vision is PaddlePaddle's high-level vision API. It bundles common datasets and popular pre-trained models, such as LeNet, the VGG series, the ResNet series, and the MobileNet series. This experiment uses resnet34 to demonstrate how to fine-tune for image classification.

With the dataset ready, load the pre-trained model by calling net = resnet34(pretrained=True). Setting the pretrained parameter to True loads the pre-trained parameters (they are downloaded automatically the first time); otherwise, the parameters must be trained from scratch.

Load the pre-trained model, setting the number of classes to 12 (the cat categories):

net = resnet34(pretrained=True, num_classes=12)
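If you prefer to be explicit, an equivalent hedged sketch is to load the full 1000-class ImageNet model and swap in a fresh 12-way classifier by hand. This assumes paddle.vision's resnet34 exposes its classifier as the fc attribute with 512 input features, which is worth verifying against your Paddle version:

import paddle.nn as nn
from paddle.vision.models import resnet34

# load the ImageNet-pretrained backbone with its original 1000-class head...
net = resnet34(pretrained=True)
# ...then replace the classifier with a randomly initialized 12-way layer;
# everything else keeps its pretrained weights, only this layer starts fresh
net.fc = nn.Linear(512, 12)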

III. Model Fine-Tuning

With the pre-trained model loaded, define the model's optimizer and evaluation metric, feed in the domain data, and run the fine-tuning:

(1) Define the optimizer

optimizer = Adam(
    parameters=net.parameters(),
    learning_rate=1e-5
)

(2) Define the loss function

loss_fn = nn.CrossEntropyLoss()

(3) Define the accuracy metric

metric_fn = Accuracy()

(4) Fine-tune for 20 epochs

for epoch in range(20):
    t0 = time.time()
    net.train()
    for data, label in train_dl:
        logit = net(data)
        loss = loss_fn(logit, label.astype('int64'))
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

    # evaluate on the training and validation sets
    net.eval()
    loss_tr = 0.
    for data, label in train_dl:
        logit = net(data)
        label = label.astype('int64')
        # float() handles both 0-d and 1-element loss tensors across Paddle versions
        loss_tr += float(loss_fn(logit, label))
    loss_tr /= len(train_dl)

    loss_va = 0.
    for data, label in valid_dl:
        label = label.astype('int64')
        logit = net(data)
        loss_va += float(loss_fn(logit, label))
        metric_fn.update(
            metric_fn.compute(logit, label)
        )
    loss_va /= len(valid_dl)
    acc_va = metric_fn.accumulate()
    metric_fn.reset()
    t = time.time() - t0
    print('[Epoch {:3d} {:.2f}s] train loss({:.4f}); valid loss({:.4f}), acc({:.2f})'
          .format(epoch, t, loss_tr, loss_va, acc_va))
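One hedged refinement to the loop above: the two evaluation passes still build the autograd graph; wrapping them in paddle.no_grad() avoids that, saving memory and time without changing the results. A sketch for the validation pass:

net.eval()
loss_va = 0.
with paddle.no_grad():          # no gradients are needed for evaluation
    for data, label in valid_dl:
        label = label.astype('int64')
        logit = net(data)
        loss_va += float(loss_fn(logit, label))
loss_va /= len(valid_dl)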

Part of the training output is shown in Figure 2 below:


IV. Model Prediction

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

def show_image(file_name):
    img = mpimg.imread('data/' + file_name)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.show()


test_ds = CatDataset('./data', mode='test', transform=transform)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)
test_pred = []
net.eval()  # make sure batch-norm layers run in inference mode
with paddle.no_grad():
    for data, in test_dl:
        logit = net(data)
        pred = paddle.argmax(
            F.softmax(logit, axis=-1),
            axis=-1
        )
        test_pred.append(pred.cpu().numpy())
test_pred = np.concatenate(test_pred, axis=0)

for image, pred in zip(test_ds.images, test_pred.astype(int)):
    img = mpimg.imread('data/' + image)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.show()
    print('Image path: %s, predicted class: %d\n' % (image.split('/')[1], pred))

Part of the prediction output is shown in Figures 3 to 6 below.


Summary

This series of articles consists of notes and reflections based on Machine Learning in Practice, published by Tsinghua University Press; all code is built on Baidu PaddlePaddle. If anything here is improper or infringes on any rights, please message me and I will gladly take care of it!

Finally, borrowing a line from this event to close the article ~( ̄▽ ̄~)~:

**The greatest reason to learn is to escape mediocrity: every day you start earlier is a day more of a remarkable life; every day you delay is another day mired in the ordinary.**
