生成器只进行12次迭代 - 无论批量大小

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了生成器只进行12次迭代 - 无论批量大小相关的知识,希望对你有一定的参考价值。

我有以下数据生成器。它工作并返回预期的数据。除了无论我设置的epochs或batchsize等于什么,它只进行12次迭代然后给出错误(见下文)

我尝试过更改时代数和批量大小。

# initialize the number of epochs to train for and batch size
NUM_EPOCHS = 10 #100
BS = 32 #64 #32

NUM_TRAIN_IMAGES = len(train_uxo_scrap)
NUM_TEST_IMAGES = len(test_uxo_scrap)
def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):
    cnt=0
    while True:
        images = []
        labels = []
        #cnt=0

        while len(images) < batchsize and cnt < len(imgfns):
            images.append(imgfns[cnt])
            labels.append(imglabels[cnt])
            cnt=cnt+1

        print(images)
        print(labels)
        print('********** cnt = ', cnt)
        yield images, labels
train_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS, class_mode='binary')

valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS, class_mode='binary')
# train the network
H = model.fit_generator(
    train_gen,
    steps_per_epoch=NUM_TRAIN_IMAGES // BS,
    validation_data=valid_gen,
    validation_steps=NUM_TEST_IMAGES // BS,
    epochs=NUM_EPOCHS)

我希望代码在每次迭代中经历10个时期,每个样本有32个样本。我每次迭代得到32个样本,但我在第一个时期只得到12个迭代,然后我得到以下错误。无论批量大小还是设定时期,都会发生这种情况。

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-83-26f81894773d> in <module>()
      5     validation_data=valid_gen,
      6     validation_steps=NUM_TEST_IMAGES // BS,
----> 7     epochs=NUM_EPOCHS)

~AppDataLocalContinuumanaconda3envsdltf1libsite-packages	ensorflowpythonkerasengine	raining.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1424         use_multiprocessing=use_multiprocessing,
   1425         shuffle=shuffle,
-> 1426         initial_epoch=initial_epoch)
   1427 
   1428   def evaluate_generator(self,

~AppDataLocalContinuumanaconda3envsdltf1libsite-packages	ensorflowpythonkerasengine	raining_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, **kwargs)
    182       # `batch_size` used for validation data if validation
    183       # data is NumPy/EagerTensors.
--> 184       batch_size = int(nest.flatten(batch_data)[0].shape[0])
    185 
    186       # Callbacks batch begin.

IndexError: tuple index out of range

以下是打印输出的示例:

['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#615.npy', ..., 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#224.npy']
[1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0]
********** cnt =  352
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#532.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#953.npy', 
...
, 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1081.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1050.npy']
[1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0]
********** cnt =  384
答案

看看这是否有效:

def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):
    while True:
        start = 0
        end = batchsize

        while start  < len(imgfns): 
            x = imgfns[start:end]
            y = imglabels[start:end]
            yield x, y

            start += batchsize
            end += batchsize

假设imgfns, imglabels是numpy数组。

以上是关于生成器只进行12次迭代 - 无论批量大小的主要内容,如果未能解决你的问题,请参考以下文章

Python记录12:迭代器+生成器+生成式

SQL SERVER 批量生成编号

Python生成器

Laravel 5.2 中的批量插入

代码详解生成器迭代器

谷歌机器学习速成课程---降低损失 (Reducing Loss):随机梯度下降法