
Posted Han Zheng, Practitioners and T



A Study of the Principles of Fine-tuning Deep Learning Networks, Using Convolutional Neural Networks as an Example

I. What is a pre-trained model?


  • VGG16/19
  • ResNet
  • ImageNet
  • COCO







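  • Step 1: Suppose our neural network has the following form: Y = W * X
  • Step 2: We want to find a W such that the input X = 2 produces the output Y = 1, which means we want W = 0.5: 1 = W * 2
  • Step 3: Following the basic training procedure for neural networks, we first initialize W with a value drawn from a distribution with mean 0 and variance 1. Suppose W is initialized to 0.1: Y = 0.1 * X
  • Step 4: Training starts with forward propagation (FP). With input X = 2 and W = 0.1, the output is Y = 0.2, so the error between the output and the target value 1 is 0.8: 1 <====== 0.2 = 0.1 * 2
  • Step 5: Backpropagation (BP) begins. The error of 0.8 is propagated backward to update the weight W. Suppose this update gives W = 0.2; the output is then 0.4 and the error against the target is 0.6: 1 <====== 0.4 = 0.2 * 2
  • Step 6: After perhaps 10 or 20 rounds of BP, W finally reaches the 0.5 we want: Y = 0.5 * X
  • Step 7: Now suppose that at initialization someone had told you that W should be around 0.47
  • Step 8: Then from the very start of training your error against the target is only 0.06, and one or two BP steps may be enough to adjust W to 0.5: 1 <====== 0.94 = 0.47 * 2

Step 7 is equivalent to being handed a pre-trained model, and Step 8 is fine-tuning on top of that pre-trained model.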
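
The eight steps above can be replayed numerically. The sketch below assumes a squared-error loss and a fixed learning rate of 0.05, neither of which is specified in the text, and counts how many gradient steps it takes for W to come within 0.01 of the target 0.5 from each starting point. The function name `steps_to_converge` is illustrative.

```python
def steps_to_converge(w_init, x=2.0, y_target=1.0, lr=0.05, tol=0.01, max_steps=1000):
    """Count gradient-descent steps until |W - y_target / x| < tol for Y = W * X."""
    w = w_init
    w_goal = y_target / x  # 0.5 for X = 2, Y = 1
    steps = 0
    while abs(w - w_goal) >= tol and steps < max_steps:
        y = w * x                          # forward pass (FP)
        grad = 2.0 * (y - y_target) * x    # d/dW of the assumed loss (y - y_target)^2
        w -= lr * grad                     # backward pass (BP): update W
        steps += 1
    return steps, w

scratch_steps, w_scratch = steps_to_converge(0.1)      # "random" initialization from Step 3
finetune_steps, w_finetune = steps_to_converge(0.47)   # "pre-trained" initialization from Step 7
print(scratch_steps, finetune_steps)
```

Under these assumed settings the run that starts at 0.47 needs fewer updates than the run that starts at 0.1, which is the whole point of Steps 7 and 8.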



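  • Overlap between the corpus used for pre-training and the downstream fine-tuning task: in essence, the weights of a pre-trained model encode the data it was fed. If the pre-training task and the downstream fine-tuning task are too far apart in domain, the pre-trained parameters give almost no speed-up and may even hurt.
  • Capacity of the pre-trained model itself: in theory, if the pre-trained model is large enough to cover part of the core of the downstream task, fine-tuning can re-adjust the weights, activating some neurons and switching off others, so that the model "grows" toward the downstream task.
  • Whether the pre-training corpus is large and diverse enough: this determines whether the model has been sufficiently pre-trained. If the upstream model has not converged, it still needs heavy adjustment once it is connected to the downstream task, which greatly slows convergence of the whole model. Conversely, if the pre-trained model has largely converged, downstream fine-tuning needs far less data: a small dataset can still give good results, with little training time.
  • Whether the input layer of the pre-trained model (vectorization scheme, tensor dimensions, embedding method, encoding scheme, shape, and so on) matches the corresponding structure of the downstream task exactly, or at least transfers to it. In theory, the input-layer structure is a form of feature-engineering experience and itself represents the model's abstraction of the target task. For example, if you forcibly convert a pixel image and feed it into a model built for text generation, neither training nor prediction will work well.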




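  • Shallow convolutional layers extract basic features such as edges and contours
  • Deep convolutional layers extract abstract features such as the overall shape of a face
  • Fully connected layers score and classify based on combinations of features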
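
To make the first bullet concrete, here is a minimal NumPy sketch of what an edge-detecting filter in a shallow convolutional layer computes. The Sobel kernel is hand-picked for illustration (a trained network learns its own filters), the 6x6 image is synthetic, and `conv2d_valid` is a hypothetical helper, not a library function.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Plain 2-D cross-correlation with no padding and stride 1."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# 6x6 toy image: dark left half (0), bright right half (1)
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Sobel kernel that responds to vertical edges
sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

# ReLU after the convolution, as in VGG-style networks
feature_map = np.maximum(conv2d_valid(image, sobel_x), 0)
print(feature_map)
```

The feature map is non-zero only in the columns that straddle the dark-to-bright boundary, so the filter acts as an edge detector and ignores the flat regions.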


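  • It avoids training from scratch, which reduces training time and saves compute
  • It reduces the risk of problems such as failure to converge, poorly optimized parameters, low accuracy, poor generalization, and overfitting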







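Dataset 3: large dataset, low data similarity

In this case, because we have a large dataset, training our neural network will be effective. However, our data differs substantially from the data used to train the pre-trained model, so predictions made with the pre-trained model will not be effective. It is therefore best to train the neural network from scratch on your own data (training from scratch).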





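  • A common practice is to cut off the last layer (the softmax layer) of the pre-trained network and replace it with a new softmax layer suited to our own problem. For example, a network pre-trained on ImageNet comes with a 1000-class softmax layer. If our task is a 10-class classification, the new softmax layer has 10 classes instead of 1000. We then run backpropagation on the network to fine-tune the pre-trained weights. Be sure to use cross-validation so that the network generalizes well.
  • Train the network with a smaller learning rate. Because we expect the pre-trained weights to be already quite good compared with randomly initialized weights, we do not want to distort them too quickly or too much. A common practice is to make the initial learning rate 10 times smaller than the one used for training from scratch.
  • If the dataset is very small, we can train only the last layer; if it is of moderate size, freezing the weights of the first few layers of the pre-trained network is also common practice. This is because the first few layers capture generic features, such as curves and edges, that are also relevant to our new problem. We want to keep those weights intact and instead let the network focus on learning dataset-specific features in the later, deeper layers.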
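
The practices above (replace the classification head, use a small learning rate, freeze the early layers) can be illustrated without any deep learning framework. The NumPy sketch below is a toy under stated assumptions: a random matrix `W_frozen` stands in for the pre-trained feature layers, the data and labels are synthetic, and the learning rate of 0.05 is arbitrary. Only the new 10-class softmax head `W_head` is trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for pre-trained layers (assumption: random weights, kept frozen)
W_frozen = rng.normal(size=(20, 16)) / np.sqrt(20)
W_frozen_before = W_frozen.copy()

# New head: 10 classes instead of the original 1000
num_classes = 10
W_head = np.zeros((16, num_classes))

# Synthetic dataset: 200 samples with 20 input features each
X = rng.normal(size=(200, 20))
y = rng.integers(0, num_classes, size=200)

def forward(inputs):
    """Frozen feature extractor followed by the trainable softmax head."""
    features = np.maximum(inputs @ W_frozen, 0)           # frozen layer + ReLU
    logits = features @ W_head
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    exp = np.exp(logits)
    return features, exp / exp.sum(axis=1, keepdims=True)

def cross_entropy(probs, targets):
    return -np.log(probs[np.arange(len(targets)), targets] + 1e-12).mean()

lr = 0.05  # deliberately small (arbitrary value for this toy)
losses = []
for step in range(100):
    features, probs = forward(X)
    losses.append(cross_entropy(probs, y))
    grad_logits = probs.copy()
    grad_logits[np.arange(len(y)), y] -= 1.0        # gradient of softmax + cross-entropy
    grad_head = features.T @ grad_logits / len(y)
    W_head -= lr * grad_head                        # only the head is updated

print(losses[0], losses[-1])
```

Because `W_frozen` never appears on the left-hand side of an update, the "pre-trained" features stay exactly as they were, while the loss of the new head goes down. In a real framework the same effect is obtained by marking the early layers as non-trainable.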






from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np

model = VGG16(weights='imagenet')

img_path = '6.webp'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])


Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]
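
For readers unfamiliar with `decode_predictions`, the sketch below shows the general idea in NumPy: pick the k largest probabilities and pair them with class names. The labels and probabilities are made up for illustration, and `top_k` is a hypothetical helper; the real function additionally maps the 1000 ImageNet class indices to their WordNet IDs and names.

```python
import numpy as np

def top_k(probs, labels, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    order = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in order]

labels = ["cat", "dog", "hook", "buckle", "car"]    # illustrative class names
probs = np.array([0.10, 0.05, 0.60, 0.20, 0.05])    # illustrative softmax output

result = top_k(probs, labels, k=3)
print(result)
```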



from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np
import cv2
import matplotlib.pyplot as plt

def vis_conv(images, n, name, t):
    """Visualize conv outputs or conv filters as an n x n grid.

    Args:
           images: feature maps (t == 'conv') or filters (t == 'filter').
           n: number of columns and rows.
           name: save name.
           t: vis type, 'conv' or 'filter'.
    """
    size = 64
    margin = 5

    # n tiles per side, separated by (n - 1) margins
    side = n * size + (n - 1) * margin
    if t == 'filter':
        results = np.zeros((side, side, 3))
    if t == 'conv':
        results = np.zeros((side, side))

    for i in range(n):
        for j in range(n):
            if t == 'filter':
                filter_img = images[i + (j * n)]
            if t == 'conv':
                filter_img = images[..., i + (j * n)]
            filter_img = cv2.resize(filter_img, (size, size))

            # Put the result in the square `(i, j)` of the results grid
            horizontal_start = i * size + i * margin
            horizontal_end = horizontal_start + size
            vertical_start = j * size + j * margin
            vertical_end = vertical_start + size
            if t == 'filter':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end, :] = filter_img
            if t == 'conv':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end] = filter_img

    # Display the results grid (the images/ directory must already exist)
    plt.imshow(results)
    plt.savefig('images/{}_{}.jpg'.format(t, name), dpi=600)

def conv_output(model, layer_name, img):
    """Get the output of conv layer.

    Args:
           model: keras model.
           layer_name: name of layer in the model.
           img: processed input image.

    Returns:
           intermediate_output: feature map.
    """
    # this is the placeholder for the input images
    input_img = model.input

    try:
        # this is the placeholder for the conv output
        out_conv = model.get_layer(layer_name).output
    except ValueError:
        raise Exception('No layer named {}!'.format(layer_name))

    # get the intermediate layer model
    intermediate_layer_model = Model(inputs=input_img, outputs=out_conv)

    # get the output of intermediate layer model
    intermediate_output = intermediate_layer_model.predict(img)

    return intermediate_output[0]

if __name__ == '__main__':
    model = VGG16(weights='imagenet')

    img_path = '6.webp'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print('Predicted:', decode_predictions(preds, top=3)[0])

    # print and visualize the feature maps of selected conv layers
    conv_layers = ["block1_conv1", "block1_conv2", "block2_conv1", "block2_conv2",
                   "block3_conv1", "block3_conv2", "block5_conv3"]
    for layer_name in conv_layers:
        feature_map = conv_output(model, layer_name, x)
        print(layer_name + ": ", feature_map)
        vis_conv(feature_map, 8, layer_name, 'conv')

    # fully connected layers output vectors, so they are only printed
    for layer_name in ["fc1", "fc2", "predictions"]:
        print(layer_name + ": ", conv_output(model, layer_name, x))

1/1 [==============================] - 2s 2s/step
Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]
1/1 [==============================] - 0s 53ms/step
block1_conv1:  [[[  0.         42.11969     0.        ...   0.         32.04823
     0.       ]
  ...
  [  2.61003    32.20762   173.75212   ...   0.        517.4678
   391.77734  ]]

 ...

 [[  0.         39.011864    0.        ...   0.          0.
     0.       ]
  ...
  [523.7337     25.070164  184.00014   ...   0.        144.83621
   315.41928  ]]]
1/1 [==============================] - 0s 84ms/step
block1_conv2:  [[[982.48444    65.59724     0.        ...  81.02978   698.99084
   172.65338  ]
  ...
  [  0.          0.          0.        ...  32.059208    0.
     7.143284 ]]

 ...

 [[  0.          0.          0.        ...   0.          0.
     0.       ]
  ...
  [287.62454     0.          0.        ...   0.        764.0485
     0.       ]]]
1/1 [==============================] - 0s 76ms/step
block2_conv1:  [[[   0.          0.        146.08685  ... 1138.9917      0.
   1914.1439  ]
  ...
  [   0.          0.       1087.817    ... 2163.6226      0.
      0.      ]]

 ...

 [[   0.        296.92758   171.61426  ...  975.3303    292.51434
   1616.5455  ]
  ...
  [   0.        435.59036  1229.345    ... 2149.0642    365.4059
      0.      ]]]
1/1 [==============================] - 0s 96ms/step
block2_conv2:  [[[  19.134865   65.2908    388.85107  ...   77.567345    0.
      0.      ]
  ...
  [ 393.64594   221.8914      0.       ...  779.5206      0.
      0.      ]]

 ...

 [[   0.          0.          0.       ...  108.35586  1594.9191
      0.      ]
  ...
  [   0.          0.        318.3686   ...  567.25116     0.
     85.41164 ]]]
1/1 [==============================] - 0s 86ms/step
block3_conv1:  [[[ 104.03467     0.       7676.7437   ...  284.6595    104.21471
    495.0637  ]
  ...
  [  30.971907    0.       7807.489    ...  149.22829   247.03853
    569.7642  ]]

 ...

 [[   0.        604.2773   7762.388    ...  492.06854   294.44586
    373.23422 ]
  ...
  [   0.        586.12274  8277.507    ...  435.10352     0.
    348.29013 ]]]
1/1 [==============================] - 0s 105ms/step
block3_conv2:  [[[   0.        971.66376   794.8841   ...  172.1506     10.597431
    794.8708  ]
  ...
  [   0.         10.890532  953.4981   ...  233.22906     0.
    579.7396  ]]

 ...

 [[ 608.40894  1603.2566    899.59283  ...  999.1029     64.82636
   1448.8973  ]
  ...
  [ 148.46483   454.86966   954.3874   ...  199.56137   320.5976
   1351.9453  ]]]
1/1 [==============================] - 0s 146ms/step
block5_conv3:  [[[0.        0.        0.        ... 0.        0.        0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]]

 ...

 [[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        1.2440066 0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]]

 ...





1. What is the pre-training and fine-tuning paradigm?


2. What is ResNet?







  1. Import the required packages
import os
import time
import os.path as osp
import zipfile

import numpy as np
import paddle
import paddle.nn as nn
import pandas as pd
import paddle.nn.functional as F
from PIL import Image
from paddle.io import Dataset, DataLoader
from paddle.optimizer import Adam
from paddle.vision import Compose, ToTensor, Resize
from paddle.vision.models import resnet34
from paddle.metric import Accuracy
from sklearn.model_selection import StratifiedShuffleSplit

  2. Split train into a training set and a validation set
info = pd.read_csv(osp.join('./data', 'train_list.txt'), sep='\t', header=None)
images, labels = info.iloc[:, 0], info.iloc[:, 1]
# one stratified 80/20 split, so both subsets keep the class proportions
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
train_idx, valid_idx = next(split.split(images, labels))
info_tr = info.iloc[train_idx, :]
info_va = info.iloc[valid_idx, :]
info_tr.to_csv('data/train.csv', header=False, index=False)
info_va.to_csv('data/valid.csv', header=False, index=False)
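
`StratifiedShuffleSplit` keeps the class proportions of the labels roughly equal in both subsets. The NumPy sketch below illustrates that idea on synthetic labels for 12 classes. It is a simplified stand-in for illustration, not scikit-learn's actual implementation, and `stratified_split` is a hypothetical helper.

```python
import numpy as np

def stratified_split(labels, test_size=0.2, seed=0):
    """Split indices so that each class contributes test_size of its samples to validation."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train_idx, valid_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)   # all samples of this class
        rng.shuffle(idx)
        n_valid = int(round(len(idx) * test_size))
        valid_idx.extend(idx[:n_valid])
        train_idx.extend(idx[n_valid:])
    return np.array(train_idx), np.array(valid_idx)

labels = np.repeat(np.arange(12), 15)   # synthetic: 12 classes x 15 images
train_idx, valid_idx = stratified_split(labels)
print(np.bincount(labels[valid_idx], minlength=12))
```

Every class ends up with the same number of validation samples, which a plain random split does not guarantee on a small dataset.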


class CatDataset(Dataset):
    train_file = 'cat_12_train.zip'
    test_file = 'cat_12_test.zip'
    train_label = 'train_list.txt'

    def __init__(self, root, mode, transform=None):
        super(CatDataset, self).__init__()
        self.root = root
        self.mode = mode
        self.transform = transform

        if not osp.isfile(osp.join(root, self.train_file)) or \
                not osp.isfile(osp.join(root, self.train_label)) or \
                not osp.isfile(osp.join(root, self.test_file)):
            raise ValueError('wrong data path')
        # unzip both archives on first use
        if not osp.isdir(osp.join(self.root, 'cat_12_train')):
            with zipfile.ZipFile(osp.join(root, self.train_file)) as f:
                f.extractall(root)
            with zipfile.ZipFile(osp.join(root, self.test_file)) as f:
                f.extractall(root)
        if mode == 'train':
            # the full labeled set
            info = pd.read_csv(osp.join(root, 'train_list.txt'), sep='\t', header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(
                info.iloc[:, 1].to_list()
            )
        elif mode == 'train_':
            # the training part of the split
            info = pd.read_csv(osp.join(root, 'train.csv'), header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(
                info.iloc[:, 1].to_list()
            )
        elif mode == 'valid_':
            # the validation part of the split
            info = pd.read_csv(osp.join(root, 'valid.csv'), header=None)
            self.images = info.iloc[:, 0].to_list()
            self.labels = paddle.to_tensor(
                info.iloc[:, 1].to_list()
            )
        else:
            # test mode: images only, no labels
            images = os.listdir(os.path.join(root, 'cat_12_test'))
            self.images = ['cat_12_test/' + image for image in images]
            self.labels = None

    def __getitem__(self, idx):
        image = Image.open(osp.join(self.root, self.images[idx]))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        if self.mode == 'test':
            return image,
        else:
            label = self.labels[idx]
            return image, label

    def __len__(self):
        return len(self.images)


paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')
transform = Compose([
    Resize([224, 224]),
    ToTensor(),
])

train_ds = CatDataset('./data', 'train_', transform)
valid_ds = CatDataset('./data', 'valid_', transform)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=False)





net = resnet34(pretrained=True, num_classes=12)




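# assumption: the original arguments are missing from the source;
# Paddle's default learning rate is used here
optimizer = Adam(parameters=net.parameters())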


loss_fn = nn.CrossEntropyLoss()


metric_fn = Accuracy()


for epoch in range(20):
    t0 = time.time()
    # training
    net.train()
    for data, label in train_dl:
        logit = net(data)
        loss = loss_fn(logit, label.astype('int64'))
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
    # validation
    net.eval()
    metric_fn.reset()
    with paddle.no_grad():
        loss_tr = 0.
        for data, label in train_dl:
            logit = net(data)
            label = label.astype('int64')
            loss_tr += loss_fn(logit, label).item()
        loss_tr /= len(train_dl)

        loss_va = 0.
        for data, label in valid_dl:
            label = label.astype('int64')
            logit = net(data)
            loss_va += loss_fn(logit, label).item()
            metric_fn.update(
                metric_fn.compute(logit, label)
            )
        loss_va /= len(valid_dl)
    acc_va = metric_fn.accumulate()
    t = time.time() - t0
    print('[Epoch {:3d} {:.2f}s] train loss({:.4f}); valid loss({:.4f}), acc({:.2f})'
          .format(epoch, t, loss_tr, loss_va, acc_va))
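
An accuracy metric of the kind used in the validation loop boils down to counting correct top-1 predictions batch by batch and dividing by the total at the end. The NumPy sketch below shows that bookkeeping. `RunningAccuracy` is a hypothetical class written for illustration and the logits are made up; it is not Paddle's implementation.

```python
import numpy as np

class RunningAccuracy:
    """Accumulates top-1 accuracy over several batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, logits, labels):
        pred = np.argmax(logits, axis=-1)             # predicted class per sample
        self.correct += int((pred == labels).sum())
        self.total += len(labels)

    def accumulate(self):
        return self.correct / self.total if self.total else 0.0

metric = RunningAccuracy()
metric.update(np.array([[2.0, 0.1], [0.3, 0.9]]), np.array([0, 0]))  # 1 of 2 correct
metric.update(np.array([[0.2, 1.5]]), np.array([1]))                  # 1 of 1 correct
accuracy = metric.accumulate()
print(accuracy)
```

Resetting the counters at the start of every epoch matters: otherwise the reported accuracy mixes results from earlier, less trained versions of the network.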



import matplotlib.image as mpimg
import matplotlib.pyplot as plt
def show_image(file_name):
    img = mpimg.imread('data/' + file_name)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

test_ds = CatDataset('./data', mode='test', transform=transform)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)
test_pred = []
net.eval()
with paddle.no_grad():
    for data, in test_dl:
        logit = net(data)
        pred = paddle.argmax(
            F.softmax(logit, axis=-1),
            axis=-1
        ).cpu().numpy()
        test_pred.append(pred)
test_pred = np.concatenate(test_pred, axis=0)

for image, pred in zip(test_ds.images, test_pred.astype(int)):
    show_image(image)
    print('Image path: %s, predicted class: %d\n' % (image.split('/')[1], pred))
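
One detail of the prediction step is worth noting: softmax preserves the ordering of its inputs, so taking `argmax` of the softmax output gives the same class as taking `argmax` of the raw logits. The softmax call is only needed when the probabilities themselves are of interest. A quick NumPy check on random logits (the batch size of 32 and the 12 classes mirror the code above; the values are synthetic):

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(32, 12))   # one batch: 32 samples, 12 classes
probs = softmax(logits)

same = np.array_equal(np.argmax(probs, axis=-1), np.argmax(logits, axis=-1))
print(same)
```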




Finally, to close the article, here is a quote from this event ~( ̄▽ ̄~)~:




A Summary of Deep Learning Fine-tuning Techniques



