python vgg_finetune.py
Posted
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了python vgg_finetune.py相关的知识，希望对你有一定的参考价值。
from keras import backend as K
from keras.applications.vgg16 import VGG16
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers import MaxPooling2D
from keras.models import Sequential
from keras.optimizers import Adam
# Hyperparameters for the classifier head built by get_bn_layers.
label_count = 17  # width of the final softmax Dense layer (number of classes)
p = 0.4  # base dropout rate; get_bn_layers uses p/4, p and p/2

# Pretrained VGG16 with the library's default arguments.
vgg = VGG16()
def split_at(model, layer_type):
    """Split ``model.layers`` just after the last layer of exactly ``layer_type``.

    Args:
        model: Object exposing a sliceable ``layers`` sequence (e.g. a Keras
            model).
        layer_type: Class to look for. Matching is by exact type
            (``type(layer) is layer_type``), so instances of subclasses of
            ``layer_type`` are NOT treated as matches.

    Returns:
        A ``(head, tail)`` pair: ``head`` is every layer up to and including
        the last match, ``tail`` is every layer after it.

    Raises:
        IndexError: if no layer in ``model.layers`` has exactly that type.
    """
    layers = model.layers
    # Only the last match matters, so scan from the end and stop at the
    # first hit instead of collecting every matching index.
    last = next(
        (idx for idx, layer in reversed(list(enumerate(layers)))
         if type(layer) is layer_type),
        None,
    )
    if last is None:
        # Same exception type as before, but with an actionable message
        # instead of a bare "list index out of range".
        raise IndexError(
            'model has no layer of exact type {!r}'.format(layer_type))
    return layers[:last + 1], layers[last + 1:]
# Cut VGG16 right after its last Conv2D layer. The convolutional part is
# wrapped in its own Sequential model; fc_layers (everything after that cut)
# is not referenced again in this file.
conv_layers, fc_layers = split_at(model=vgg, layer_type=Conv2D)
conv_model = Sequential(layers=conv_layers)
def get_bn_layers(p, bn_axis=None):
    """Build the layer list for a batch-normalised classifier head.

    The head is meant to sit on top of the VGG16 convolutional layers: its
    input shape is the output shape of the last layer in the module-level
    ``conv_layers``.

    Args:
        p: Base dropout rate. The three Dropout layers use ``p / 4``, ``p``
            and ``p / 2`` respectively.
        bn_axis: Channel axis for the BatchNormalization applied to the pooled
            feature maps. ``None`` (default) derives it from the Keras
            ``image_data_format`` setting: ``1`` for ``'channels_first'``,
            ``-1`` otherwise.

    Returns:
        A list of Keras layers ending in a ``label_count``-way softmax,
        suitable for passing to ``Sequential``.
    """
    if bn_axis is None:
        # The conv features follow the backend's data format, so the BN axis
        # must as well. A hard-coded axis=1 only fits channels_first; under
        # channels_last it would normalise over feature-map rows instead.
        bn_axis = 1 if K.image_data_format() == 'channels_first' else -1
    return [
        # Pool the feature maps produced by the last conv layer.
        MaxPooling2D(input_shape=conv_layers[-1].output_shape[1:]),
        BatchNormalization(axis=bn_axis),
        Dropout(p / 4),
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p / 2),
        Dense(label_count, activation='softmax'),
    ]
# Classifier head trained on top of the VGG16 convolutional features.
bn_model = Sequential(get_bn_layers(p))
# NOTE(review): `lr` keeps compatibility with Keras 2.0-2.2; newer Keras
# spells this `learning_rate`.
bn_model.compile(Adam(lr=0.001), loss='categorical_crossentropy',
                 metrics=['accuracy'])

batch_size = 64

# NOTE(review): conv_feat / trn_labels / conv_val_feat / val_labels are not
# defined in this file. bn_model's input shape is the conv-layer output
# shape, so these are presumably conv_model.predict(...) outputs on the
# training/validation images plus their one-hot labels — TODO define them.

# Stage 1: a few epochs at the initial learning rate.
bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, epochs=3,
             validation_data=(conv_val_feat, val_labels))

# Stage 2: lower the learning rate and keep training. Assigning a float to
# `optimizer.lr` would only rebind the attribute; the training function
# built by the first fit keeps using the old backend variable, so the
# change would be silently ignored. Update the variable's value instead.
K.set_value(bn_model.optimizer.lr, 1e-4)
bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, epochs=7,
             validation_data=(conv_val_feat, val_labels))
以上是关于python vgg_finetune.py的主要内容，如果未能解决你的问题，请参考以下文章。
001--python全栈--基础知识--python安装
Python代写,Python作业代写,代写Python,代做Python
Python开发
Python,python,python
Python 介绍
Python学习之认识python