在连体 CNN 上使用 .fit_generator 时出错

Posted

技术标签:

【中文标题】在连体 CNN 上使用 .fit_generator 时出错【英文标题】:Error when using .fit_generator on Siamese CNN 【发布时间】:2020-08-25 19:32:24 【问题描述】:

我们正在尝试拟合 Siamese CNN,但在我们想要使用 .fit_generator 将数据提供给模型的最后一部分遇到了麻烦。

我们的生成器函数如下所示:

def get_batch(h, w, batch_size = 100):

    anchor =np.zeros((batch_size,h,w,3))
    positive =np.zeros((batch_size,h,w,3))
    negative =np.zeros((batch_size,h,w,3))

    while True:
    #Choose index at random
        index = np.random.choice(n_row, batch_size)
        for i in range(batch_size):
            list_ind = train_triplets.iloc[index[i],]
            #print(list_ind)
            anchor[i] =  train_data[list_ind[0]]
            positive[i] = train_data[list_ind[1]]
            negative[i] = train_data[list_ind[2]]

            anchor = anchor.astype("float32")
            positive = positive.astype("float32")
            negative = negative.astype("float32")

        yield [anchor,positive,negative]



该模型期望获得一个包含 3 个数组的列表作为 Siamese CNN 的输入。然而,我们得到以下 错误信息:

Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays

如果我们只是手动提供一个包含 3 个数组的列表,那么它就可以工作。这就是为什么我们怀疑错误是由 .fit_generator 函数引起的。我们必须使用 .fit_generator 函数,因为由于内存问题我们无法存储数据。

有人知道这是为什么吗?

提前谢谢。

【问题讨论】:

【参考方案1】:

根据错误,模型需要 3 个数组,而不是 3 个数组的列表。如此变化 yield [anchor,positive,negative]yield anchor,positive,negative 可能会起作用。

【讨论】:

以上是关于在连体 CNN 上使用 .fit_generator 时出错的主要内容,如果未能解决你的问题,请参考以下文章

找出两个卷积神经网络(CNN)的输出之间的距离,即连体网络

TensorFlow 中的连体神经网络

具有 LSTM 网络的连体模型无法使用 tensorflow 进行训练

用于 Keras 中句子相似性的具有 LSTM 的连体网络定期给出相同的结果

综合日语第一册第十四课

将 CNN 用于分类和数字数据等数据在理论上是不是合理?