Keras:使用 model.fit() 洗牌数据不会做出改变,但 sklearn.train_test_split() 会
Posted
技术标签:
【中文标题】Keras:使用 model.fit() 洗牌数据不会做出改变,但 sklearn.train_test_split() 会【英文标题】:Keras: Shuffling data using model.fit() doesn't make a change but sklearn.train_test_split() does 【发布时间】:2021-05-01 20:03:11 【问题描述】:我是 Keras 的新手,遇到了一个我不明白的问题,到目前为止我在互联网上也找不到任何解决方案。
我使用以下几行代码在 UrbanSound8K 数据集上训练一个简单的模型:
x_train, y_train, _, _ = load_data(["data_1.pickle", "data_5.pickle"])
#x_train, _, y_train, _ = train_test_split(x_train, y_train, test_size=0.01, random_state = 0, shuffle=True)
model = Sequential()
model.add(Dense(256, input_shape=(40,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')
model.fit(x_train, y_train, validation_split=0.2, batch_size=32, epochs=50, shuffle=True)
当我训练这个模型时,它达到了大约 50% 的 val_accuracy。将 model.fit()
中的 shuffle 更改为 False 似乎没有任何影响。
但是,当我取消注释第二行并使用x_train, _, y_train, _ = train_test_split(x_train, y_train, test_size=0.01, random_state = 0, shuffle=True)
对数据集进行洗牌时,模型的 val_accuracy 达到了 80% 以上!无论model.fit()
shuffle 设置为True 还是False。
这怎么可能?在拟合模型之前对数据进行洗牌应该没有任何区别,因为它的训练数据在每个时期之前都会被洗牌?还是我误解了model.fit()
的参数shuffle?或者train_test_split()
有什么额外的魔法发生?
【问题讨论】:
【参考方案1】:您正在使用 0.2 的验证拆分。现在根据它说明的 model.fit 文档
The validation data is selected from the last samples in the x and y data provided, before shuffling.
所以我唯一能想到的是,当您不使用 train_test_split 时,model.fit 使用的验证数据始终是从未打乱的训练数据末尾获取的相同数据。当您使用 train_test_split 时,训练数据会被打乱,因此在这种情况下验证数据会有所不同。如果验证集的大小很小,这可能会对计算出的验证准确度产生巨大影响,因为两种情况下的验证样本不同。我认为 model.fit 从训练数据的末尾选择验证数据是不好的做法。它应该从训练数据中随机选择它。即使有相当多的验证样本,如果训练样本末尾的数据与其余训练数据的概率分布明显不同,这可能会导致验证准确度低得多。例如,如果您要对狗和猫进行分类,并且在训练集中,最后的所有图像都是猫,那么验证图像将都是猫。
【讨论】:
非常感谢,这确实是解决方案!为了确认这一点,我在洗牌数据上训练了网络,但在未洗牌数据上对其进行了验证,它达到了大约 50-60% 的验证准确度,所以这确实是这里的问题。 欢迎,我花了一段时间才弄明白以上是关于Keras:使用 model.fit() 洗牌数据不会做出改变,但 sklearn.train_test_split() 会的主要内容,如果未能解决你的问题,请参考以下文章
使用tf.train时,使用tf.dataset的Keras model.fit()会失败
当训练数据是图像时,Keras model.fit() 中的“批次”是啥
keras.models.Model.fit 中的“时代”是啥?
使用 keras.utils.Sequence 和 keras.model.fit_generator 时出现 KeyError。