scikit learn 删除频率较低的分类类

Posted

技术标签:

【中文标题】scikit learn 删除频率较低的分类类【英文标题】:scikitlearn remove less frequent categorical classes 【发布时间】:2019-09-20 17:48:32 【问题描述】:

我正在做一个分类任务,其中不同类的数量为 1500。我想从中删除频率小于 10 的那些类(和相应的记录)。

我可以写一个类似这样的函数:

code_freq_hash = 
for code in y:
    code_freq_hash.setdefault(code, 0)
    code_freq_hash[code] += 1

获取每个类的频率,然后删除相应的记录。 但是,我想知道在 scikit learn 或 keras 中是否有内置函数可以做到这一点

【问题讨论】:

你的数据集是 Pandas 数据框还是 numpy 数组?基于 pandas/numpy 的解决方案会起作用吗? numpy 解决方案将起作用 【参考方案1】:

这是一个使用 numpy 和 pandas 的示例解决方案。


创建具有两个特征和一个 class 列的数据集

data = np.hstack((np.array(np.random.randn(20,2)), np.random.choice(np.arange(20), (20,1))))

麻木

val, count = np.unique(data[:,-1], return_counts=True)
val[count>2]
out = data[np.isin(data[:, -1], val[np.isin(val, val[count>2])])] # replace 2 with 10 for your problem

熊猫

将数据集(numpy 数组)转换为 pandas 数据帧

df = pd.DataFrame(data)
# renamming the last column to the name "class"
df.rename(columns= df.columns[-1]: "class" , inplace=True)

    0                  1    class
0   0.542154    -0.434981   3.0
1   1.513857    -0.606722   17.0
2   0.372834    -0.120914   0.0
3   -1.357369   1.575805    5.0
4   0.547217    0.719883    4.0
5   0.818016    -0.243919   9.0
6   -0.400552   0.066519    19.0
7   0.463596    1.020041    6.0
8   0.850465    -0.814260   14.0
9   1.693060    0.186741    17.0
10  -0.287775   -0.190247   3.0
11  -0.390932   -0.418964   6.0
12  0.209542    0.797151    5.0
13  0.126585    -0.345196   5.0
14  -0.151729   -1.260708   4.0
15  -1.042408   1.050194    6.0
16  -0.221668   1.763742    5.0
17  -0.045617   1.159383    5.0
18  1.452508    -0.785115   5.0
19  2.125601    1.745009    2.0

计算出现次数并仅过滤出现两次以上的类(在您的情况下设置为 2 到 10)

d = df.loc[df['class'].isin(df['class'].value_counts().index[df['class'].value_counts() > 2])]

你可以获取numpy数组为d.values

array([[-1.35736852,  1.57580524,  5.        ],
       [ 0.46359614,  1.02004142,  6.        ],
       [-0.39093188, -0.41896435,  6.        ],
       [ 0.20954221,  0.79715056,  5.        ],
       [ 0.12658469, -0.34519613,  5.        ],
       [-1.04240815,  1.05019427,  6.        ],
       [-0.2216682 ,  1.76374209,  5.        ],
       [-0.0456175 ,  1.15938322,  5.        ],
       [ 1.45250806, -0.78511526,  5.        ]])

【讨论】:

【参考方案2】:

在 Sklearn 中没有直接的解决方案,但正如您提到的,它可以通过自定义函数来实现。

import pandas as pd
import numpy as np

df = pd.DataFrame('labels': np.random.randint(0,10,size=50000),
                  'input': np.random.choice(['sample text 1','sample text 1'],size=50000))
threshold = 5000

labels_df=df.labels.value_counts()
filtered_labels = labels_df[labels_df>threshold].index

new_df = df.loc[df['labels'].isin(filtered_labels),:]
new_df.shape
#(25290, 2)

【讨论】:

【参考方案3】:

一个解决方案可能是以下代码sn-p:

import numpy as np
unique, appearances = np.unique(a, return_counts=True)
code_freq_hash = [(unique[i], appearances[i]) for i in range(len(unique)) if appearances[i] >= 10]

更优雅,如下所述,relevant_labels = unique[appearances >= 10]

【讨论】:

以上是关于scikit learn 删除频率较低的分类类的主要内容,如果未能解决你的问题,请参考以下文章

需要帮助将 scikit-learn 应用于这个不平衡的文本分类任务

如何使用 DecisionTreeClassifier 平衡分类?

如何在训练期间为 Scikit Learn SVM 中的每个标签分配概率?

scikit-learn TF-IDF

中频IF

时钟配置