在连体 CNN 上使用 .fit_generator 时出错
Posted
技术标签:
【中文标题】在连体 CNN 上使用 .fit_generator 时出错【英文标题】:Error when using .fit_generator on Siamese CNN 【发布时间】:2020-08-25 19:32:24 【问题描述】:我们正在尝试拟合 Siamese CNN,但在我们想要使用 .fit_generator 将数据提供给模型的最后一部分遇到了麻烦。
我们的生成器函数如下所示:
def get_batch(h, w, batch_size = 100):
anchor =np.zeros((batch_size,h,w,3))
positive =np.zeros((batch_size,h,w,3))
negative =np.zeros((batch_size,h,w,3))
while True:
#Choose index at random
index = np.random.choice(n_row, batch_size)
for i in range(batch_size):
list_ind = train_triplets.iloc[index[i],]
#print(list_ind)
anchor[i] = train_data[list_ind[0]]
positive[i] = train_data[list_ind[1]]
negative[i] = train_data[list_ind[2]]
anchor = anchor.astype("float32")
positive = positive.astype("float32")
negative = negative.astype("float32")
yield [anchor,positive,negative]
该模型期望获得一个包含 3 个数组的列表作为 Siamese CNN 的输入。然而,我们得到以下 错误信息:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays
如果我们只是手动提供一个包含 3 个数组的列表,那么它就可以工作。这就是为什么我们怀疑错误是由 .fit_generator 函数引起的。我们必须使用 .fit_generator 函数,因为由于内存问题我们无法存储数据。
有人知道这是为什么吗?
提前谢谢。
【问题讨论】:
【参考方案1】:根据错误,模型需要 3 个数组,而不是 3 个数组的列表。如此变化
yield [anchor,positive,negative]
到 yield anchor,positive,negative
可能会起作用。
【讨论】:
以上是关于在连体 CNN 上使用 .fit_generator 时出错的主要内容,如果未能解决你的问题,请参考以下文章
具有 LSTM 网络的连体模型无法使用 tensorflow 进行训练