Keras fit_generator() fails with a shape error
I am running MNIST prediction with Keras on the TensorFlow backend. I have code that trains batch by batch using Keras fit():
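import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator

(X_train, y_train), (X_test, y_test) = mnist.load_data()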
N1 = X_train.shape[0]
N2 = X_test.shape[0]
h = X_train.shape[1]
w = X_train.shape[2]
num_pixels = h*w
# reshape N1 samples to num_pixels
x_train = X_train.reshape(N1, num_pixels).astype('float32') # shape is now (60000,784)
x_test = X_test.reshape(N2, num_pixels).astype('float32') # shape is now (10000,784)
x_train = x_train / 255
x_test = x_test / 255
y_train = np_utils.to_categorical(y_train) #(60000,10)
y_test = np_utils.to_categorical(y_test) # (10000,10):
num_classes = y_test.shape[1]
def baseline_model():
    # create model
    model = Sequential()
    model.add(Dense(num_pixels, input_dim=num_pixels, kernel_initializer='normal', activation='relu'))
    model.add(Dense(num_classes, kernel_initializer='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
model = baseline_model()
batch_size = 200
epochs = 20
max_batches = 2 * len(x_train) / batch_size # 2*60000/200
# reshape to be [samples][width][height][ channel] for ImageDataGenerator
x_t = X_train.reshape(N1, w, h, 1).astype('float32')
datagen = ImageDataGenerator(rescale= 1./255)
train_gen = datagen.flow(x_t, y_train, batch_size=batch_size)
for e in range(epochs):
    batches = 0
    for x_batch, y_batch in train_gen:
        # x_batch is of size [batch_sz,w,h,ch]: resize to [bth_sz,pixel_sz]: (200,28,28,1)-> (200,784)
        # for model.fit
        x_batch = np.reshape(x_batch, [-1, num_pixels])
        model.fit(x_batch, y_batch, validation_split=0.15, verbose=0)
        batches += 1
        print("Epoch %d/%d, Batch %d/%d" % (e+1, epochs, batches, max_batches))
        # the generator loops forever, so stop each epoch manually
        if batches >= max_batches:
            break
scores = model.evaluate(x_test, y_test, verbose=0)
However, when I try to implement similar code with fit_generator(), I get an error. The code is as follows:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# separate data into train and validation
from sklearn.model_selection import train_test_split
# Split the data
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.15, shuffle= True)
# number of training samples
N1 = X_train.shape[0] # training size
N2 = X_test.shape[0] # test size
N3 = X_valid.shape[0] # valid size
h = X_train.shape[1]
w = X_train.shape[2]
num_pixels = h*w
y_train = np_utils.to_categorical(y_train)
y_valid = np_utils.to_categorical(y_valid)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]
def baseline_model():
    # create model
    model = Sequential()
    model.add(Dense(num_pixels, input_dim=num_pixels, kernel_initializer='normal', activation='relu'))
    model.add(Dense(num_classes, kernel_initializer='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
model = baseline_model()
batch_size = 200
epochs = 20
steps_per_epoch_tr = int(N1/ batch_size) # 51000/200
steps_per_epoch_val = int(N3/batch_size)
# reshape to be [samples][width][height][channel] for ImageDataGenerator -> datagen.flow
x_t = X_train.reshape(N1, w, h, 1).astype('float32')
x_v = X_valid.reshape(N3, w, h, 1).astype('float32')
# define data preparation
datagen = ImageDataGenerator(rescale=1./255) # scales x_t/x_v
train_gen = datagen.flow(x_t, y_train, batch_size=batch_size)
valid_gen = datagen.flow(x_v,y_valid, batch_size=batch_size)
model.fit_generator(train_gen, steps_per_epoch=steps_per_epoch_tr,
                    validation_data=valid_gen,
                    validation_steps=steps_per_epoch_val, epochs=epochs)
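This gives an error:
I understand this is caused by a mismatch in the expected image dimensions, but I don't know where or how to fix it. Any help is greatly appreciated. Thanks, sedy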
Answer
In the model.fit() case, this line flattens each batch before it is passed to training:
x_batch = np.reshape(x_batch, [-1, num_pixels])
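To make the effect of that reshape concrete, here is a minimal NumPy sketch. The zero-filled array is only a stand-in for one batch from datagen.flow(), and the name flat is introduced for illustration; the batch size of 200 and the 28 x 28 x 1 image shape are taken from the question.

```python
import numpy as np

num_pixels = 28 * 28  # 784, as in the question

# Stand-in for one batch from datagen.flow(): (batch, width, height, channel)
x_batch = np.zeros((200, 28, 28, 1), dtype="float32")

# -1 lets NumPy infer the batch dimension from the total element count
flat = np.reshape(x_batch, [-1, num_pixels])

print(x_batch.shape)  # (200, 28, 28, 1)
print(flat.shape)     # (200, 784)
```

The second shape is what a Dense layer declared with input_dim=num_pixels expects; the first is what the generator actually yields.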
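In the generator case, however, nothing flattens the input before it reaches the first Dense layer. The generator yields batches of shape (batch, 28, 28, 1), while the first Dense layer was declared with input_dim=num_pixels and therefore expects input of shape (batch, 784). Adding a Flatten() layer to the model resolves the mismatch, as shown below.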
from keras.layers import Dense, Flatten

def baseline_model():
    # create model
    model = Sequential()
    # flatten each (28, 28, 1) image into a 784-vector before the Dense layers
    model.add(Flatten(input_shape=(28, 28, 1)))
    # input_dim is dropped here: the input shape is already set by Flatten
    model.add(Dense(num_pixels, kernel_initializer='normal', activation='relu'))
    model.add(Dense(num_classes, kernel_initializer='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
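An alternative that is not part of the answer above, offered only as a sketch: keep the original model unchanged and flatten each batch on the generator side instead. The helper flatten_batches and the stand-in generator fake_flow are hypothetical names introduced here; fake_flow only mimics the batch shapes that datagen.flow() would yield for the data in the question.

```python
import numpy as np

def flatten_batches(batch_gen):
    """Yield (x, y) batches with x reshaped from (batch, w, h, ch) to (batch, w*h*ch)."""
    for x_batch, y_batch in batch_gen:
        yield x_batch.reshape(len(x_batch), -1), y_batch

def fake_flow(n_batches=3, batch_size=200):
    # Hypothetical stand-in for datagen.flow(x_t, y_train, batch_size=200)
    for _ in range(n_batches):
        yield (np.zeros((batch_size, 28, 28, 1), dtype="float32"),
               np.zeros((batch_size, 10), dtype="float32"))

shapes = [(x.shape, y.shape) for x, y in flatten_batches(fake_flow())]
print(shapes[0])  # ((200, 784), (200, 10))
```

With the real generators, the wrapped objects would presumably be passed as flatten_batches(train_gen) and flatten_batches(valid_gen) to model.fit_generator(). The Flatten() layer is usually the simpler choice, since it keeps the reshaping inside the model and also applies at prediction time.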