Sklearn train_test_split 拆分数据集以将预测标签与地面实况标签进行比较
Posted
技术标签:
【中文标题】Sklearn train_test_split 拆分数据集以将预测标签与地面实况标签进行比较【英文标题】:Sklearn train_test_split split a dataset to compare predicted labels with ground truth labels 【发布时间】:2021-10-23 07:10:58 【问题描述】:我正在尝试通过改编自 this guide 使用带有小数据集的 SVM 执行多类文本分类。输入 csv 包含一个“文本”列和一个“标签”列(已为此特定任务手动分配)。
需要为每个文本条目分配一个标签。通过使用 LinearSVC 模型和 TfidfVectorizer,我获得了 75% 的准确度分数,这对于只有 400 个样本的非常小的数据集来说似乎比预期的要高。为了进一步提高准确性,我想查看未正确分类的条目,但在这里我有一个问题。因为我是这样使用 train_test_split 的:
Train_X, Test_X, Train_Y, Test_Y = train_test_split(X, y, test_size=0.1, random_state = 1004)
我不知道 train_test_split 函数使用了哪些文本条目(据我了解,该函数随机选择 test_size 的 10% 条目)。所以我不知道我应该将测试数据集的预测标签列表与语料库原始条目标签的哪个子集进行比较。换句话说,是否有一种方法可以强制为 test_size 分配一个子集,即数据集中 400 个总条目中的最后 40 个条目?
这将有助于手动比较预测标签与真实标签。
下面是代码:
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import os
class Config:
# Data and output directory config
data_path = r'./take3/Data'
code_train = r'q27.csv'
if __name__ == "__main__":
print('--------Code classification--------\n')
Corpus = pd.read_csv(os.path.join(Config.data_path, Config.code_train), sep = ',', encoding='cp1252', usecols=['text', 'label'])
train_text = ['' if type(t) == float else t for t in Corpus['text'].values]
# todo fine tunining
tfidf = TfidfVectorizer(
sublinear_tf=True,
min_df=3, norm='l2',
encoding='latin-1',
ngram_range=(1, 2),
stop_words='english')
X = tfidf.fit_transform(train_text) # Learn vocabulary and idf, return document-term matrix.
# print('Array mapping from feature integer indices to feature name',tfidf.get_feature_names())
print('X.shape:', X.shape)
y = np.array(list(Corpus['label']))
print('The corpus original labels:',y)
print('y.shape:', y.shape)
Train_X, Test_X, Train_Y, Test_Y = train_test_split(X, y, test_size=0.1, random_state = 1004)
model = LinearSVC(random_state=1004)
model.fit(Train_X, Train_Y)
SVM_predict_test = model.predict(Test_X)
accuracy = accuracy_score(Test_Y, SVM_predict_test, normalize=True, sample_weight=None)*100
print('Predicted labels for the test dataset', SVM_predict_test)
print("SVM accuracy score: :.4f".format(accuracy))
这是接收到的输出:
--------Code classification--------
X.shape: (400, 136)
The corpus original labels: [15 20 9 14 98 12 3 4 4 22 99 3 98 20 99 1 10 20 8 15 98 12 18 7
20 99 8 8 13 2 8 6 22 4 98 5 98 12 18 8 98 18 24 4 3 19 12 5
20 6 8 15 5 14 19 22 16 10 24 16 98 8 8 16 2 20 4 8 20 6 22 98
3 98 15 12 2 13 5 8 8 1 10 16 20 12 7 20 98 22 99 10 12 8 8 16
16 4 4 99 20 8 16 2 12 15 16 10 5 22 8 7 7 4 5 12 16 14 1 10
22 20 4 4 5 99 16 3 5 22 99 5 3 4 4 3 6 99 8 20 2 10 98 6
6 8 99 3 8 99 2 5 15 6 6 7 8 14 9 4 20 3 99 5 98 15 5 5
20 10 4 99 99 16 22 8 10 22 98 12 3 5 9 99 14 8 9 18 20 14 15 20
20 1 6 23 22 20 6 1 18 8 12 10 15 10 6 10 3 4 8 24 14 22 5 3
22 24 98 98 98 4 15 19 5 8 1 17 16 6 22 19 4 8 2 15 12 99 16 8
9 1 8 22 14 5 20 2 10 10 22 12 98 3 19 5 98 14 19 22 18 16 98 16
6 4 24 98 24 98 15 1 3 99 5 10 22 4 16 98 22 1 8 4 20 8 8 5
20 4 3 20 22 4 20 12 7 21 5 4 16 8 22 20 99 5 6 99 8 3 4 99
6 8 12 3 10 4 8 5 14 20 6 99 4 4 6 4 98 21 1 23 20 98 19 6
4 22 98 98 20 10 8 10 19 16 14 98 14 12 10 4 22 14 3 98 10 20 98 10
9 7 3 8 3 6 6 98 8 99 1 20 18 8 2 6 99 99 99 14 14 16 20 99
1 98 23 6 12 4 1 3 99 99 3 22 5 7 16 99]
y.shape: (400,)
Predicted labels for the test dataset [ 1 8 5 4 15 10 14 12 6 8 8 16 98 20 7 99 99 12 99 24 4 98 99 3
20 3 6 14 18 98 99 22 4 99 4 10 14 4 3 98]
SVM accuracy score: 75.0000
【问题讨论】:
【参考方案1】:train_test_split
的默认行为是将数据拆分为随机训练和测试子集。您可以通过设置 shuffle=False
并删除 random_state
来强制执行静态子集拆分。
Train_X, Test_X, Train_Y, Test_Y = train_test_split(X, y, test_size=0.1, shuffle=False)
见How to get a non-shuffled train_test_split in sklearn
【讨论】:
通过将 train_test_split 函数与答案中建议的参数一起应用,我能够将数据集的最后 10% 作为测试数据集,然后手动将预测标签与基本事实进行比较。谢谢。以上是关于Sklearn train_test_split 拆分数据集以将预测标签与地面实况标签进行比较的主要内容,如果未能解决你的问题,请参考以下文章
我无法导入 sklearn.model_selection.train_test_split
sklearn.model_selection.train_test_split 中分层方法的(无效参数)错误