keras中的级联模型(自动编码器+分类器)
Posted
技术标签:
【中文标题】keras中的级联模型(自动编码器+分类器)【英文标题】:cascaded model (autoencoder + classifier) in keras 【发布时间】:2017-10-22 19:01:49 【问题描述】:我正在构建一个级联模型(一个与分类器堆叠的自动编码器模型)。自编码器的输入是一组图像,自编码器的输出将被输入到预训练的分类器中。
auto_input= Input(shape=(ch, height, width), name='x_autoen')
auto_output = autoencoder(auto_input)
auto_model = Model(input=auto_input, output=auto_output)
class_output = classifier(auto_output)
class_model = Model(input=auto_output, output=class_output)
cascade_model = Model(input=auto_input, output=[auto_output, class_output])
load_classifier_weights(cascade_model, classifier_weights_path)
auto_model.compile(optimizer='sgd', loss='mean_squared_error')
class_model.compile(optimizer='sgd', loss='binary_crossentropy')
cascade_model.compile(optimizer='sgd', loss='binary_crossentropy')
但这会返回以下错误。
File "xxxx.py", line 33, in build_model
class_model = Model(input=auto_output, output=class_output)
File "/xxx/local/lib/python2.7/site-packages/keras/engine/topology.py", line 1987, in __init__
str(layers_with_complete_input))
RuntimeError: Graph disconnected: cannot obtain value for tensor x_autoen at layer "x_autoen". The following previous layers were accessed without issue: []
分类代码:
def classifier(inputs):
conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(inputs)
conv1 = Dropout(0.2)(conv1)
conv1 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
#
conv2 = Convolution2D(64, 3, 3, activation='relu'', border_mode='same')(pool1)
conv2 = Dropout(0.2)(conv2)
conv2 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
#
conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(pool2)
conv3 = Dropout(0.2)(conv3)
conv3 = Convolution2D(128, 3, 3, activation='relu', border_mode='same')(conv3)
up1 = merge([UpSampling2D(size=(2, 2))(conv3), conv2], mode='concat', concat_axis=1)#192x24x24
conv4 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(up1)
conv4 = Dropout(0.2)(conv4)
conv4 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(conv4)
#
up2 = merge([UpSampling2D(size=(2, 2))(conv4), conv1],, mode='concat', concat_axis=1)#96x48x48
conv5 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(up2)
conv5 = Dropout(0.2)(conv5)
conv5 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(conv5)
#
conv6 = Convolution2D(2, 1, 1, activation='relu', border_mode='same')(conv5)
conv6 = core.Reshape((2,patch_height*patch_width))(conv6)
conv6 = core.Permute((2,1))(conv6)
conv7 = core.Activation('softmax')(conv6)
return conv7
根据大牛的评论改正后出错:
ValueError: The model expects 2 input arrays, but only received one array. Found: array with shape (1000, 1, 48, 48)
这是我用来训练级联网络的代码。
cascade_model .fit(imgs_train, imgs_train, nb_epoch=epochs, batch_size=batch_size, verbose=2, shuffle=True, validation_split=0.1, callbacks=[checkpointer])
enter code here
【问题讨论】:
首先,检查您的“分类器”模型。是全连接的吗?我的意思是输入是否通过所有层直到输出?听起来模型被消息打断了。 (将代码发布到“分类器”模型可能会有所帮助)。 是的,是丹尼尔。我已经发布了上面的代码 您的模型需要两个您定义的输出:cascade_model .fit(imgs_train, [imgs_train,classes_train]
。如果你不想给出两个输出,那么只定义一个输出。
您应该在编译之前制作class_model.trainable = false
和所有class_model.layers[i].trainable = false
。您应该在某处设置每个分类器层的权重。
你也可以在模型的某一层复制“layer.get_weights()”来对比训练前后的效果。
【参考方案1】:
确实,分类器似乎是连接的。
那我猜你可能需要给class_model
一个独立的输入。
系统可能无法在图的中间启动模型(图是层的序列,其中一层的输出作为另一层的输入)。
尽管一切似乎都相互关联,但您将图中间的张量作为模型的输入传递。这可能会产生问题。
当我这样做时,我会这样做:
class_input = Input((shapeforclassifierinput))
class_output = classifier(class_input)
class_model = Model(input=class_input, output=class_output)
#If this gives an error, then your classifier is indeed not connected
#Then I'd suggest using the Concatenate(axis=1)([UpSampling2D(size=(2, 2))(conv3), conv2])
现在,您可以通过这种方式加入两个模型:
cascade_input = Input(shape=(ch, height, width))
auto_out = auto_model(cascade_input)
class_out = class_model(auto_out)
cascade_model = Model(input=cascade_input, output=[auto_out, class_out])
【讨论】:
这种方式可以消除该错误。但是,当我尝试训练级联网络时,会出现另一个错误。请在上面找到错误的详细信息。以上是关于keras中的级联模型(自动编码器+分类器)的主要内容,如果未能解决你的问题,请参考以下文章