python vgg_finetune.py

Preface: This article was compiled by the editors of 小常识网 (cha138.com). It covers python vgg_finetune.py, and we hope you find it a useful reference.

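# Fine-tune VGG16: reuse its convolutional base as a fixed feature extractor
# and train a new batch-normalized dense head on the extracted features.
# Written against the Keras 2.x API (Keras 2.3+ / tf.keras).
from keras.applications.vgg16 import VGG16
from keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                          Flatten, MaxPooling2D)
from keras.models import Sequential
from keras.optimizers import Adam

vgg = VGG16()  # ImageNet weights; expects 224x224x3 inputs

p = 0.4           # base dropout rate
label_count = 17  # number of target classes
batch_size = 64

def split_at(model, layer_type):
    """Split model.layers right after the last layer of exactly layer_type."""
    layers = model.layers
    layer_idx = [index for index, layer in enumerate(layers)
                 if type(layer) is layer_type][-1]
    return layers[:layer_idx + 1], layers[layer_idx + 1:]

# Everything up to block5_conv3 becomes the feature extractor; the final
# pooling layer and VGG16's original dense layers end up in fc_layers.
conv_layers, fc_layers = split_at(vgg, Conv2D)
conv_model = Sequential(conv_layers)

def get_bn_layers(p):
    return [
        MaxPooling2D(input_shape=conv_layers[-1].output_shape[1:]),
        # Normalize over the channel axis: -1 for channels_last data
        # (the TensorFlow default); use 1 for channels_first data.
        BatchNormalization(axis=-1),
        Dropout(p / 4),
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p / 2),
        Dense(label_count, activation='softmax')
    ]

bn_model = Sequential(get_bn_layers(p))
bn_model.compile(Adam(learning_rate=0.001), loss='categorical_crossentropy',
                 metrics=['accuracy'])

# The head consumes convolutional features, not raw images, so run the base
# once and cache its outputs. trn / val (preprocessed image arrays) and
# trn_labels / val_labels (one-hot labels) must be loaded beforehand.
conv_feat = conv_model.predict(trn, batch_size=batch_size)
conv_val_feat = conv_model.predict(val, batch_size=batch_size)

bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, epochs=3,
             validation_data=(conv_val_feat, val_labels))

# Lower the learning rate in place, then keep training.
bn_model.optimizer.learning_rate.assign(1e-4)
bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, epochs=7,
             validation_data=(conv_val_feat, val_labels))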
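The `split_at` helper in the script keeps every layer up to and including the last layer of the requested type, and puts everything after it in the second list. For VGG16 that means `block5_pool` lands on the dense side, which is why the new head begins with its own `MaxPooling2D`. The sketch below checks this behaviour without Keras; the layer classes, `ToyModel`, and the layer names are hypothetical stand-ins that only mimic a Keras model's `.layers` list.

```python
# Hypothetical stand-in layer classes; split_at only looks at their types.
class Layer:
    def __init__(self, name):
        self.name = name

class InputLayer(Layer): pass
class Conv2D(Layer): pass
class MaxPooling2D(Layer): pass
class Flatten(Layer): pass
class Dense(Layer): pass

class ToyModel:
    """Minimal stand-in exposing a .layers list, like a Keras model."""
    def __init__(self, layers):
        self.layers = layers

def split_at(model, layer_type):
    layers = model.layers
    # Index of the LAST layer whose exact type is layer_type.
    layer_idx = [i for i, layer in enumerate(layers)
                 if type(layer) is layer_type][-1]
    return layers[:layer_idx + 1], layers[layer_idx + 1:]

# A shortened, VGG-like layer stack (names are illustrative only).
toy = ToyModel([
    InputLayer("input"),
    Conv2D("block1_conv1"),
    MaxPooling2D("block1_pool"),
    Conv2D("block2_conv1"),
    Conv2D("block2_conv2"),
    MaxPooling2D("block2_pool"),
    Flatten("flatten"),
    Dense("fc1"),
    Dense("predictions"),
])

conv_part, fc_part = split_at(toy, Conv2D)
print([layer.name for layer in conv_part])
print([layer.name for layer in fc_part])
```

Two consequences of the design: the test is `type(layer) is layer_type`, so subclasses of `Conv2D` are not matched, and a model with no matching layer makes `[-1]` raise `IndexError`.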

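To see where the head built by `get_bn_layers` spends its capacity, its parameter counts can be worked out by hand. This is a sketch under stated assumptions: channels-last features of shape 14×14×512 (what `block5_conv3` produces for a 224×224 input), a default 2×2 `MaxPooling2D`, the first `BatchNormalization` normalizing over the 512 channels, and two hidden layers of 512 units with 17 output classes, as in the script.

```python
def dense_params(n_in, n_out):
    # Weight matrix plus bias vector.
    return n_in * n_out + n_out

def batchnorm_params(channels):
    # gamma and beta (trainable) plus moving mean and variance (non-trainable).
    return 4 * channels

def head_param_counts(feat_shape=(14, 14, 512), hidden=512, label_count=17):
    h, w, c = feat_shape
    # Default MaxPooling2D: 2x2 window, stride 2, 'valid' padding.
    pooled_h, pooled_w = h // 2, w // 2
    flat = pooled_h * pooled_w * c
    counts = {
        "bn_pool": batchnorm_params(c),
        "dense_1": dense_params(flat, hidden),
        "bn_1": batchnorm_params(hidden),
        "dense_2": dense_params(hidden, hidden),
        "bn_2": batchnorm_params(hidden),
        "output": dense_params(hidden, label_count),
    }
    return flat, counts

flat, counts = head_param_counts()
print("flattened features:", flat)
for name, n in counts.items():
    print(f"{name:>8}: {n:,}")
print(f"{'total':>8}: {sum(counts.values()):,}")
```

Under these assumptions the first `Dense(512)` layer, fed by 7×7×512 = 25,088 flattened features, holds almost all of the head's roughly 13.1 million parameters. That is the layer the heavier `Dropout(p)` sits behind, and caching `conv_feat` once keeps each training epoch cheap, since only this head is updated.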