Keras:网络不使用 fit_generator() 进行训练
Posted
技术标签:
【中文标题】Keras:网络不使用 fit_generator() 进行训练【英文标题】:Keras: network doesn't train with fit_generator() 【发布时间】:2017-06-20 21:38:12 【问题描述】:我在大型数据集上使用 Keras(使用 MagnaTagATune 数据集进行音乐自动标记)。因此,我尝试将 fit_generator() 功能与自定义数据生成器一起使用。但是损失函数和指标的值在训练过程中不会改变。看起来我的网络根本没有训练。
当我使用 fit() 函数而不是 fit_generator() 时,一切正常,但我无法将整个数据集保存在内存中。
我已经尝试过 Theano 和 TensorFlow 后端
主要代码:
if __name__ == '__main__':
model = models.FCN4()
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'categorical_accuracy', 'precision', 'recall'])
gen = mttutils.generator_v2(csv_path, melgrams_dir)
history = model.fit_generator(gen.generate(0,750),
samples_per_epoch=750,
nb_epoch=80,
validation_data=gen.generate(750,1000,False),
nb_val_samples=250)
# RESULTS SAVING
np.save(output_history, history.history)
model.save(output_model)
类生成器_v2:
genres = ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock', 'fast',
'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing',
'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal',
'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice',
'female vocal', 'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice', 'choral']
def __init__(self, csv_path, melgrams_dir):
def get_dict_vals(dictionary, keys):
vals = []
for key in keys:
vals.append(dictionary[key])
return vals
self.melgrams_dir = melgrams_dir
with open(csv_path, newline='') as csvfile:
reader = csv.DictReader(csvfile, dialect='excel-tab')
self.labels = []
for row in reader:
labels_arr = np.array(get_dict_vals(
row, self.genres)).astype(np.int)
labels_arr = labels_arr.reshape((1, labels_arr.shape[0]))
if (np.sum(labels_arr) > 0):
self.labels.append((row['mp3_path'], labels_arr))
self.size = len(self.labels)
def generate(self, begin, end):
while(1):
for count in range(begin, end):
try:
item = self.labels[count]
mels = np.load(os.path.join(
self.melgrams_dir, item[0] + '.npy'))
tags = item[1]
yield((mels, tags))
except FileNotFoundError:
continue
要为 fit() 函数准备数组,我使用以下代码:
def TEST_get_data_array(csv_path, melgrams_dir):
gen = generator_v2(csv_path, melgrams_dir).generate(0,100)
item = next(gen)
x = np.array(item[0])
y = np.array(item[1])
for i in range(0,100):
item = next(gen.training)
x = np.concatenate((x,item[0]),axis = 0)
y = np.concatenate((y,item[1]),axis = 0)
return(x,y)
对不起,如果我的代码风格不好。谢谢!
UPD 1:
我尝试使用return(X,y)
而不是yield(X,y)
,但没有任何变化。
我的新生成器类的一部分:
def generate(self):
if((self.count < self.begin) or (self.count >= self.end)):
self.count = self.begin
item = self.labels[self.count]
mels = np.load(os.path.join(self.melgrams_dir, item[0] + '.npy'))
tags = item[1]
self.count = self.count + 1
return((mels, tags))
def __next__(self): # fit_generator() uses this method
return self.generate()
fit_generator 调用:
history = model.fit_generator(tr_gen,
samples_per_epoch = tr_gen.size,
nb_epoch = 120,
validation_data = val_gen,
nb_val_samples = val_gen.size)
日志:
Epoch 1/120
10554/10554 [==============================] - 545s - loss: 1.7240 - acc: 0.8922
Epoch 2/120
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820
Epoch 3/120
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820
Epoch 4/120
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820
... etc (loss is always 1.8922; acc is always 0.8820)
【问题讨论】:
在for count in range(begin, end)
之前,您可以对数据进行洗牌。
@Ladislao 我也面临同样的问题。你能告诉我你是按照什么程序来解决这个问题的吗?提前谢谢
@prasanna,正如 cmets 中提到的最佳答案,我刚刚在批次中放置了更多元素,它有所帮助。
【参考方案1】:
在使用 yield 方法时,我遇到了和你一样的问题。所以我只存储了当前索引并使用 return 语句在每次调用时返回一批。
所以我只使用了return (X, y)
而不是yield (X,y)
并且它起作用了。我不确定这是为什么。如果有人能对此有所了解,那就太酷了。
编辑: 您需要将生成器传递给函数,而不仅仅是调用函数。像这样:
model.fit_generator(gen, samples_per_epoch=750,
nb_epoch=80,
validation_data=gen,
nb_val_samples=250)
Keras 将调用您的 __next__ 函数,同时对数据进行训练。
【讨论】:
我试过了,但没有任何改变。请检查我是否理解正确(我的带有return
语句的代码在主帖的末尾)。谢谢!
像这样通过生成器时应该可以工作。如果没有,您可以发布您的错误消息吗?
是的,我正在像这样将我的生成器传递给fit_generator
函数。没有异常或错误。问题是损失函数的值在训练过程中没有改变(我已经在主帖中添加了日志)。看起来网络没有刷新它的权重。这不可能是模型中的错误,因为fit
函数(使用数组而不是生成器)可以正常工作。
您的批量大小为 1。尝试每次迭代将更多元素传递给模型。这意味着让您的下一个方法返回例如 32 个元素。也许您的类内差异太大而无法使用 1 作为batch_size。
您找到解决问题的方法了吗?【参考方案2】:
在“generate”方法中,有一个while语句。
def generate(self, begin, end):
while(1): # this
for count in range(begin, end):
try:
# something
yield(...)
except FileNotFoundError:
continue
我觉得这个说法是不需要的,所以
def generate(self, begin, end):
for count in range(begin, end):
try:
# something
yield(...)
except FileNotFoundError:
continue
【讨论】:
引发异常:File "/usr/local/lib/python3.4/dist-packages/keras/engine/training.py", line 1528, in fit_generator str(generator_output)) ValueError: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: None
Generator 必须是无穷无尽的,因为它必须在下一个 epoch 返回同一批数据以上是关于Keras:网络不使用 fit_generator() 进行训练的主要内容,如果未能解决你的问题,请参考以下文章
如何在Keras中使用fit_generator()来加权? [关闭]
如何在 keras fit_generator() 中定义 max_queue_size、workers 和 use_multiprocessing?
keras/scikit-learn:使用 fit_generator() 进行交叉验证
keras 入门整理 如何shuffle,如何使用fit_generator
keras训练函数fit和fit_generator对比,图像生成器ImageDataGenerator数据增强
使用 keras.utils.Sequence 和 keras.model.fit_generator 时出现 KeyError。