python keras_cnn_cifar10.py

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了 Python 脚本 keras_cnn_cifar10.py(使用 Keras 在 CIFAR-10 数据集上训练卷积神经网络)的相关知识,希望对你有一定的参考价值。

# -*- coding: utf-8 -*-
from __future__ import print_function

from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD
from keras.utils import np_utils
from keras.callbacks import EarlyStopping

from utils import draw_accuracy, draw_loss

# Training configuration read by the script body below.
batch_size = 32            # passed as batch_size to fit() / datagen.flow()
nb_classes = 10            # width of the one-hot labels and of the final Dense layer
nb_epoch = 200             # passed as nb_epoch to fit() / fit_generator()
data_augmentation = True   # selects the ImageDataGenerator training branch

# Input image dimensions (rows and columns are equal).
img_rows = img_cols = 32

# Number of colour channels (RGB).
img_channels = 3

if __name__ == "__main__":
    # Script entry point: load CIFAR-10, build a small CNN, train it (with or
    # without real-time augmentation, per `data_augmentation`), print the test
    # loss/accuracy and draw the accuracy/loss plots from the training history.

    # load cifar10
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()

    # Convert to float32 and divide by 255 (presumably maps 8-bit pixel
    # values into [0, 1] -- confirm against the dataset loader).
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255.0
    X_test /= 255.0

    print('X_train.shape:', X_train.shape)
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    # One-hot encode the labels so they match the softmax output and the
    # categorical cross-entropy loss.
    y_train = np_utils.to_categorical(y_train, nb_classes)
    y_test = np_utils.to_categorical(y_test, nb_classes)

    # build model
    model = Sequential()

    # Take the input shape from the loaded data instead of hard-coding
    # (img_channels, img_rows, img_cols): the axis order of the arrays
    # returned by load_data() depends on the backend's image dim ordering,
    # and a hard-coded channels-first shape breaks under channels-last.
    model.add(Convolution2D(32, 3, 3, border_mode='same',
                            input_shape=X_train.shape[1:]))
    model.add(Activation('relu'))

    model.add(Convolution2D(32, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Convolution2D(64, 3, 3, border_mode='same'))
    model.add(Activation('relu'))

    model.add(Convolution2D(64, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    # Classifier head: flatten, one hidden dense layer, softmax over classes.
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    # training
    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='categorical_crossentropy',
                  optimizer=sgd,
                  metrics=['accuracy'])

    # early stopping: used by BOTH training branches below
    early_stopping = EarlyStopping(monitor='val_loss', patience=2)

    if not data_augmentation:
        print('Not using data augmentation.')
        hist = model.fit(X_train, y_train,
                         batch_size=batch_size,
                         nb_epoch=nb_epoch,
                         validation_split=0.1,
                         shuffle=True,
                         callbacks=[early_stopping])
    else:
        print('Using real-time data augmentation.')

        # this will do preprocessing and realtime data augmentation
        datagen = ImageDataGenerator(
            featurewise_center=False,
            samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False,
            zca_whitening=False,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            vertical_flip=False)

        # compute quantities required for featurewise normalization
        # (all featurewise options are disabled above; the call is kept so
        # that enabling one of them later works without further changes)
        datagen.fit(X_train)

        # fit the model on the batches generated by datagen.flow()
        # NOTE(review): the test set doubles as validation data here, so
        # early stopping is driven by test-set loss -- consider holding out
        # part of the training data instead.
        hist = model.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
                                   samples_per_epoch=X_train.shape[0],
                                   nb_epoch=nb_epoch,
                                   validation_data=(X_test, y_test),
                                   callbacks=[early_stopping])

    # evaluation
    score = model.evaluate(X_test, y_test, verbose=1)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

    # draw accuracy/loss plot
    draw_accuracy(hist, title='cifar10_cnn')
    draw_loss(hist, title='cifar10_cnn')

以上是关于python keras_cnn_cifar10.py的主要内容,如果未能解决你的问题,请参考以下文章

win10安不了python怎么办

Python/Windows 10 - 有没有办法防止 Windows 10 计算机在 Python 中进入睡眠状态?

python随机生成100内的10个整数?

Centos7.9安装python3.10

Python模块之random

Python进阶---python strip() split()函数实战(转)