在使用 k 折交叉验证训练训练数据后如何测试数据?
Posted
技术标签:
【中文标题】在使用 k 折交叉验证训练训练数据后如何测试数据?【英文标题】:how to test the data after training the train data with k-fold cross validation? 【发布时间】:2020-09-14 12:42:45 【问题描述】:在代码中,我有:
-
将数据集分成两部分:训练集和测试集 (7:3)。该数据集由 200 行和 9394 列组成。
定义模型
使用的交叉验证:训练集上的 10 折
每次折叠获得的准确度
获得的平均准确率:94.29%
困惑是:
-
我的做法是否正确?
是否以正确的方式使用 cross_val_predict() 来预测测试数据上的 x?
剩余任务:
-
绘制模型的准确度。
绘制模型的损失图。
任何人都可以在这方面提出建议。 抱歉写了这么长的笔记!!!
数据集如下:(这些是新闻标题和正文中每个单词的tfidf)
Unnamed: 0 Unnamed: 0.1 Label Cosine_Similarity c0 c1 c2 c3 c4 c5 ... c9386 c9387 c9388 c9389 c9390 c9391 c9392 c9393 c9394 c9395
0 0 0 Real 0.180319 0.000000 0.0 0.000000 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 1 1 Real 0.224159 0.166667 0.0 0.000000 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 2 2 Real 0.233877 0.142857 0.0 0.000000 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 3 3 Real 0.155789 0.111111 0.0 0.000000 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 4 4 Real 0.225480 0.000000 0.0 0.111111 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
代码和输出为:
df_all = pd.read_csv("C:/Users/shiva/Desktop/allinone200.csv")
dataset=df_all.values
x=dataset[0:,3:]
Y= dataset[0:,2]
encoder = LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
y = np_utils.to_categorical(encoded_Y)
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=15,shuffle=True)
x_train.shape,y_train.shape
def baseline_model():
model = Sequential()
model.add(Dense(512, activation='relu',input_dim=x_train.shape[1]))
model.add(Dense(64, activation='relu')))
model.add(Dense(2, activation='softmax'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
模型拟合代码:
estimator = KerasClassifier(build_fn=baseline_model, epochs=5, batch_size=4, verbose=1)
kf = KFold(n_splits=10, shuffle=True,random_state=15)
for train_index, test_index in kf.split(x_train,y_train):
print("Train Index: ", train_index, "\n")
print("Test Index: ", test_index)
取出结果的代码:
results = cross_val_score(estimator, x_train, y_train, cv=kf)
print results
输出:
[0.9285714 1. 0.9285714 1. 0.78571427 0.85714287
1. 1. 0.9285714 1. ]
平均准确度:`
print("Accuracy: %0.2f (+/-%0.2f)" % (results.mean()*100, results.std()*2))
输出:
Accuracy: 94.29 (+/-0.14)
预测代码:
from sklearn.model_selection import cross_val_predict
y_pred = cross_val_predict(estimator, x_test, y_test,cv=kf)
print(y_test[0])
print(y_pred[0])
输出:处理后
[1. 0.]
0
这里的预测似乎还可以。因为 1 是 REAL 而 O 是 FALSE。 y_test 为 0,y_predict 也为 0。
混淆矩阵:
import numpy as np
y_test=np.argmax(y_test, axis=1)
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
cm
输出:
array([[32, 0],
[ 1, 27]], dtype=int64)
【问题讨论】:
我不知道你的数据集,但是 200 行和 9394 列的准确度为 94% 听起来非常可疑。通常,您需要至少 2 倍的数据点(行)数量,因为根据经验,您拥有特征(列),甚至可以获得一些不错的结果。此外,在您的输出结果中,还有多个 100% 的准确度,这又是非常可疑的。此外,您从未在示例代码中定义 kf,因此无法知道带有参数 cv=kf 的 cross_val_predict 是否正常工作。 感谢@Andreas Hofmann 的快速回复。我已经放了数据集,代码中提到了kf。能否请您再审查一遍。我在正确的轨道上吗? 批量大小会影响模型的准确性吗?当我增加 batch_size(比如 batch_size= 50)时,观察到的准确率为 87%。任何人,请提出建议。 【参考方案1】:根据 Andreas 关于您的观察次数的评论,这对您有任何帮助吗:Keras - Plot training, validation and test set accuracy
最佳
【讨论】:
批量大小会影响模型的准确性吗?当我增加 batch_size(比如 batch_size=50)时,观察到的准确率为 87%。任何人,请提出建议。【参考方案2】:不幸的是,我的评论变得很长,因此我在这里尝试一下:
请看一下:https://medium.com/mini-distill/effect-of-batch-size-on-training-dynamics-21c14f7a716e 简而言之,较大的批量通常会产生更差的结果但速度更快,在您的情况下这可能无关紧要(200 行)。 其次,您没有(可重复使用的)保留,这可能会给您关于您的真实准确性的错误假设。第一次尝试的准确率超过 90% 可能意味着:过度拟合、泄漏或不平衡的数据(例如这里:https://www.kdnuggets.com/2017/06/7-techniques-handle-imbalanced-data.html)或者你真的很幸运。 K-fold 与小样本量相结合可能会导致错误的假设: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0224365
一些经验法则: 1. 您希望数据点(行)是特征(列)的 2 倍。 2. 如果你仍然得到一个好的结果,这可能意味着多方面的事情。很可能是代码或方法错误。
假设您必须预测银行的欺诈风险。如果发生欺诈的可能性是 1%,我可以为您构建一个 99% 正确的模型,只需简单地说从来没有任何欺诈......
神经网络非常强大,有好有坏。坏事是他们几乎总能找到某种模式,即使没有。如果你给他们 2000 列本质上它有点像数字“Pi”,如果你在逗号后面的数字中搜索足够长的时间,你会找到你想要的任何数字组合。 这里有更详细的解释: https://medium.com/@jennifer.zzz/more-features-than-data-points-in-linear-regression-5bcabba6883e
【讨论】:
以上是关于在使用 k 折交叉验证训练训练数据后如何测试数据?的主要内容,如果未能解决你的问题,请参考以下文章