python mnist_autoencoder.py

Posted

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 python mnist_autoencoder.py 相关的知识，希望对你有一定的参考价值。

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import matplotlib.pyplot as plt

def prepare_images(images):
    """Scale an image batch by 1/255 as float32 and flatten each image.

    Args:
        images: Array whose first axis indexes images; the remaining axes
            are collapsed into a single feature axis.

    Returns:
        A float32 array of shape ``(len(images), prod(images.shape[1:]))``.
    """
    scaled = images.astype('float32') / 255.0
    # np.prod (rather than -1) keeps the reshape well-defined for an empty
    # batch, where the feature size cannot be inferred.
    return scaled.reshape((len(scaled), int(np.prod(scaled.shape[1:]))))


def build_models(encoding_dim, input_dim=784):
    """Build a single-hidden-layer autoencoder and its two halves.

    Args:
        encoding_dim: Width of the bottleneck (encoded) layer.
        input_dim: Length of each flattened input vector.

    Returns:
        A tuple ``(autoencoder, encoder, decoder)``. The decoder reuses the
        autoencoder's last layer, so it shares the trained weights.
    """
    input_img = Input(shape=(input_dim, ))
    encoded = Dense(encoding_dim, activation='relu')(input_img)
    decoded = Dense(input_dim, activation='sigmoid')(encoded)

    # Keras 2+ spells these keywords ``inputs``/``outputs``; the Keras 1
    # names ``input``/``output`` are no longer accepted.
    autoencoder = Model(inputs=input_img, outputs=decoded)

    # create a separate encoder model
    encoder = Model(inputs=input_img, outputs=encoded)

    # create a separate decoder model
    encoded_input = Input(shape=(encoding_dim, ))
    decoder_layer = autoencoder.layers[-1]
    decoder = Model(inputs=encoded_input,
                    outputs=decoder_layer(encoded_input))

    return autoencoder, encoder, decoder


def show_reconstructions(originals, reconstructions, n=10, side=28):
    """Plot the first ``n`` originals above their reconstructions.

    Args:
        originals: Flattened input images, one per row.
        reconstructions: Flattened decoded images, one per row.
        n: Number of image pairs (columns) to draw.
        side: Edge length used to reshape each row into a square image.
    """
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # Top row holds the originals, bottom row the reconstructions.
        for row, images in enumerate((originals, reconstructions)):
            ax = plt.subplot(2, n, i + 1 + row * n)
            # Pass the colormap explicitly instead of calling plt.gray(),
            # which changes the global default on every call.
            plt.imshow(images[i].reshape(side, side), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()


if __name__ == "__main__":
    encoding_dim = 32

    autoencoder, encoder, decoder = build_models(encoding_dim)

    # compile model
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

    # load mnist datasets (labels are not needed for an autoencoder)
    (X_train, _), (X_test, _) = mnist.load_data()

    X_train = prepare_images(X_train)
    X_test = prepare_images(X_test)

    print(X_train.shape)
    print(X_test.shape)

    # train model: the input is also the target
    # (``epochs`` replaces the Keras 1 ``nb_epoch`` keyword)
    autoencoder.fit(X_train, X_train,
                    epochs=50,
                    batch_size=256,
                    shuffle=True,
                    validation_data=(X_test, X_test))

    # encode and decode some digits
    encoded_imgs = encoder.predict(X_test)
    decoded_imgs = decoder.predict(encoded_imgs)

    show_reconstructions(X_test, decoded_imgs, n=10)

以上是关于 python mnist_autoencoder.py 的主要内容，如果未能解决你的问题，请参考以下文章：

001--python全栈--基础知识--python安装

Python代写,Python作业代写,代写Python,代做Python

Python开发

Python,python,python

Python 介绍

Python学习之认识python