Sklearn Linear SVM cannot train in multilabel classification

Posted: 2021-05-19 00:52:46

Question:

I want to train a linear SVM for multi-label classification with the following code:

from sklearn.svm import LinearSVC
from sklearn.multioutput import MultiOutputClassifier
import numpy as np

data = np.loadtxt('tictac_multi.txt')
X = data[:,:9]
y = data[:,9:]

clf = MultiOutputClassifier(LinearSVC(random_state=0, tol=1e-5, C=100, penalty='l2',max_iter=2000))
clf.fit(X, y)
print(clf.score(X, y))

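The dataset can be found here: https://www.connellybarnes.com/work/class/2016/deep_learning_graphics/proj1/tictac_multi.txt

I tried tuning different parameters such as C, tol, max_iter, etc., but the linear SVM model still does not train well. No matter which parameter I adjust, the training accuracy stays below 0.01...

The output of the code above is: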

Warning (from warnings module):
  File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
    warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.

[The same warning is printed eight more times, nine in total.]
0.011601282246985194

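With the current code, the accuracy is 0.0116.

Comments:

Hi! The link you provided does not work. Could you provide an excerpt of the input data and describe your classification problem better (what the inputs are, what the classes are, etc.)?

@PieCot Thanks for reaching out. The problem is a multi-label binary problem. Each input is a vector of 9 values, each of which is 0, 1 or -1. Each output is also a vector of 9 values, each of which is 0 or 1. For example, input vector = [0 1 1 0 0 -1 0 0 0], output vector = [1 0 0 1 1 0 1 1 1].

@PieCot I updated the link. Could you try again? Thank you very much.

I think your data is not linearly separable. In that case, a linear SVM will not help.

Answer 1: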

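It looks like a tic-tac-toe dataset (judging from the file name and the format).

Assuming that the first 9 columns of the dataset describe the 9 cells at a given moment of the game, and that the other 9 columns mark the cells that correspond to good moves, you can train a classifier cell by cell, in order to predict whether a cell is a good move.

So you actually need to train 9 binary classifiers, not one. Based on this idea, I sketched a very simple approach in the code below. After splitting the dataset into train/test (80/20), start with a simple cross-validation: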

import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate, train_test_split
from sklearn.metrics import classification_report
import pandas as pd

# Load data, creating a DataFrame holding inputs and outputs
df = pd.read_csv('tictac_multi.txt', sep=' ', header=None)[list(range(18))].copy()
# Name the columns ('input', 'x1') ... ('input', 'x9'), ('output', 'x1') ... ('output', 'x9')
df.columns = pd.MultiIndex.from_product((('input', 'output'), [f'x{i}' for i in range(1, 10)]))

# split dataset 80/20 (also shuffle it)
X_train, X_test, y_train, y_test = train_test_split(df['input'].values, df['output'].values, test_size=0.2, random_state=42)

# Get scores from cross-validation, one binary classifier per output column
scores = {
    s: cross_validate(
        LinearSVC(random_state=0, dual=False, class_weight='balanced', tol=1e-5),
        X_train, y_train[:, i], cv=5, scoring=['balanced_accuracy', 'precision', 'recall', 'f1_weighted'],
        n_jobs=-1,
    ) for i, s in enumerate(sorted(df['output'].columns))
}

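As you can see, I used some non-default options for the classifier (dual=False, class_weight='balanced'): they are just educated guesses. You should investigate further to better understand the data and the problem, and then look for the best parameters for your model (e.g., with a grid search).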
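The grid search mentioned above could look roughly like the sketch below for a single output column. This is only an illustration under assumptions: the data is a synthetic stand-in (not the real file), the label rule for `y_col` is made up, and the candidate values in `param_grid` are arbitrary rather than recommended settings.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

# Hypothetical stand-in data for one output column (not the real dataset):
# 9 inputs in {-1, 0, 1} and a made-up binary label.
rng = np.random.default_rng(0)
X_train = rng.integers(-1, 2, size=(400, 9)).astype(float)
y_col = (X_train[:, 0] + X_train[:, 4] > 0).astype(int)

# Candidate values are arbitrary illustrations, not tuned recommendations.
param_grid = {
    'C': [0.01, 0.1, 1, 10],
    'class_weight': [None, 'balanced'],
}

# 5-fold cross-validated search, scored with balanced accuracy
# because the positive class is the minority.
search = GridSearchCV(
    LinearSVC(random_state=0, dual=False, tol=1e-5),
    param_grid,
    scoring='balanced_accuracy',
    cv=5,
)
search.fit(X_train, y_col)

print(search.best_params_)
print(round(search.best_score_, 3))
```

With the real data, the same search would be repeated for each of the 9 output columns, replacing `y_col` with `y_train[:, i]`.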

Here are the cross-validation scores:

{'x1': {'fit_time': array([0.01000571, 0.00814652, 0.00937247, 0.00622296, 0.00536656]),
  'score_time': array([0.01159358, 0.00597596, 0.00835085, 0.00647163, 0.00619125]),
  'test_balanced_accuracy': array([0.52209841, 0.51820565, 0.53743952, 0.55455645, 0.53620968]),
  'test_precision': array([0.25454545, 0.25      , 0.26611227, 0.27659574, 0.26295585]),
  'test_recall': array([0.5060241 , 0.52016129, 0.51612903, 0.5766129 , 0.55241935]),
  'test_f1_weighted': array([0.56543736, 0.55328701, 0.58232694, 0.57711117, 0.56292617])},
 'x2': {'fit_time': array([0.00737047, 0.00885296, 0.00616217, 0.00707698, 0.0071764 ]),
  'score_time': array([0.00650406, 0.00595641, 0.00623679, 0.00636506, 0.00567913]),
  'test_balanced_accuracy': array([0.57367382, 0.5342687 , 0.55287658, 0.56565243, 0.57909451]),
  'test_precision': array([0.22520661, 0.20041754, 0.21073559, 0.22037422, 0.23175966]),
  'test_recall': array([0.5828877 , 0.51336898, 0.56684492, 0.56684492, 0.57446809]),
  'test_f1_weighted': array([0.6183652 , 0.60068273, 0.59707974, 0.61584554, 0.63060231])},
 'x3': {'fit_time': array([0.0067966 , 0.00759745, 0.00617337, 0.00679278, 0.00650382]),
  'score_time': array([0.00605631, 0.00537109, 0.00551271, 0.00665474, 0.00649571]),
  'test_balanced_accuracy': array([0.52683332, 0.54103562, 0.56227539, 0.53312408, 0.51986383]),
  'test_precision': array([0.25502008, 0.26639344, 0.28367347, 0.26035503, 0.25      ]),
  'test_recall': array([0.51626016, 0.52845528, 0.56275304, 0.53441296, 0.53036437]),
  'test_f1_weighted': array([0.56805171, 0.58208858, 0.59506983, 0.56776364, 0.55079222])},
 'x4': {'fit_time': array([0.00649667, 0.00767159, 0.00802064, 0.00769711, 0.00611663]),
  'score_time': array([0.00572419, 0.00529647, 0.00616765, 0.00592041, 0.00609517]),
  'test_balanced_accuracy': array([0.53369766, 0.57259312, 0.57644138, 0.55746825, 0.51877354]),
  'test_precision': array([0.19791667, 0.22290389, 0.22540984, 0.21489362, 0.18930041]),
  'test_recall': array([0.51351351, 0.58602151, 0.59139785, 0.54301075, 0.49462366]),
  'test_f1_weighted': array([0.6005693 , 0.615313  , 0.61784599, 0.61784823, 0.58924053])},
 'x5': {'fit_time': array([0.00650501, 0.005898  , 0.00682783, 0.00629449, 0.00635648]),
  'score_time': array([0.00553894, 0.0059135 , 0.00625896, 0.00583744, 0.00580502]),
  'test_balanced_accuracy': array([0.51108635, 0.50499149, 0.52183641, 0.53230958, 0.51296946]),
  'test_precision': array([0.30185185, 0.29735234, 0.31163708, 0.322     , 0.30522088]),
  'test_recall': array([0.53094463, 0.47557003, 0.51465798, 0.52272727, 0.49350649]),
  'test_f1_weighted': array([0.5248707 , 0.53861778, 0.54612005, 0.55679291, 0.54217533])},
 'x6': {'fit_time': array([0.00703621, 0.00908065, 0.00665092, 0.00619102, 0.00814819]),
  'score_time': array([0.00568652, 0.00626183, 0.00584817, 0.00574327, 0.00552726]),
  'test_balanced_accuracy': array([0.55457928, 0.55569106, 0.50701258, 0.53690769, 0.56919396]),
  'test_precision': array([0.2145749 , 0.21621622, 0.18480493, 0.20416667, 0.22540984]),
  'test_recall': array([0.56084656, 0.55026455, 0.47619048, 0.51851852, 0.57894737]),
  'test_f1_weighted': array([0.60241544, 0.61008882, 0.5813744 , 0.60080544, 0.6130977 ])},
 'x7': {'fit_time': array([0.0070405 , 0.00908256, 0.00702643, 0.00635576, 0.00632381]),
  'score_time': array([0.00546646, 0.00674367, 0.00542998, 0.00671315, 0.00549483]),
  'test_balanced_accuracy': array([0.53124816, 0.52187224, 0.54180051, 0.57438252, 0.52764072]),
  'test_precision': array([0.27054108, 0.26235741, 0.27659574, 0.30364372, 0.26824034]),
  'test_recall': array([0.52325581, 0.53488372, 0.55642023, 0.58365759, 0.48638132]),
  'test_f1_weighted': array([0.56745684, 0.54860915, 0.56677092, 0.5996452 , 0.57954721])},
 'x8': {'fit_time': array([0.00761437, 0.00997519, 0.006984  , 0.00623441, 0.00683069]),
  'score_time': array([0.00540686, 0.00635052, 0.00645804, 0.00535131, 0.00548935]),
  'test_balanced_accuracy': array([0.51471322, 0.56996108, 0.52712724, 0.5443143 , 0.55319282]),
  'test_precision': array([0.18661258, 0.22292994, 0.192607  , 0.20408163, 0.20874751]),
  'test_recall': array([0.49462366, 0.56451613, 0.53513514, 0.54054054, 0.56756757]),
  'test_f1_weighted': array([0.58328382, 0.62374708, 0.57815794, 0.60051373, 0.59779516])},
 'x9': {'fit_time': array([0.00723267, 0.0069263 , 0.00828266, 0.00672913, 0.00750995]),
  'score_time': array([0.00545311, 0.00556946, 0.00732398, 0.0056181 , 0.00555682]),
  'test_balanced_accuracy': array([0.53490307, 0.55281703, 0.58447809, 0.52272419, 0.54294236]),
  'test_precision': array([0.26388889, 0.27868852, 0.29811321, 0.25506073, 0.27198364]),
  'test_recall': array([0.53413655, 0.54618474, 0.63453815, 0.5060241 , 0.532     ]),
  'test_f1_weighted': array([0.56987212, 0.58922553, 0.59075641, 0.56631422, 0.5819019 ])}}

As you can see, they are not great, but they are far from 0.
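One likely reason these per-cell numbers look so different from the 0.0116 in the question: `MultiOutputClassifier.score` reports exact-match accuracy, where a row counts as correct only if all 9 labels are predicted correctly, while the scores above are computed one column at a time. As rough arithmetic, if each column were right about 55% of the time and the errors were independent, exact match would be around 0.55^9, roughly 0.005. The sketch below illustrates the difference on random synthetic data (an assumption, not the real file):

```python
import numpy as np
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import LinearSVC

# Synthetic stand-in for the real file (hypothetical data):
# 9 inputs in {-1, 0, 1} and 9 random binary outputs.
rng = np.random.default_rng(0)
X = rng.integers(-1, 2, size=(500, 9)).astype(float)
y = rng.integers(0, 2, size=(500, 9))

clf = MultiOutputClassifier(LinearSVC(random_state=0, dual=False))
clf.fit(X, y)
y_pred = clf.predict(X)

# MultiOutputClassifier fits one LinearSVC per output column.
n_estimators = len(clf.estimators_)

# Exact-match accuracy: a row counts only if all 9 labels are right.
subset_acc = float(np.mean(np.all(y_pred == y, axis=1)))

# Per-column accuracy: fraction of correct predictions for each label alone.
per_label_acc = (y_pred == y).mean(axis=0)

print(n_estimators)
print(subset_acc, clf.score(X, y))
print(per_label_acc.round(2))
```

The exact-match value can never exceed the accuracy of any single column, so a low overall score does not by itself mean that the individual classifiers learned nothing.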

Now, train the models on the whole training dataset and evaluate their performance on the test data:

def train_clfs(clfs, X, y):
    # Fit each classifier on its own output column (columns follow the sorted key order)
    return {s: clf.fit(X, y[:, i]) for i, (s, clf) in enumerate(sorted(clfs.items()))}


def get_predictions(clfs, inp):
    # Predict every output column with its own classifier
    return {s: clf.predict(inp) for s, clf in clfs.items()}

# Train the classifiers
clfs = {s: LinearSVC(random_state=0, dual=False, class_weight='balanced', tol=1e-5) for s in sorted(df['output'].columns)}
clfs = train_clfs(clfs, X_train, y_train)

# Try them on the test values
pred = get_predictions(clfs, X_test)

# Get the classification report for each classifier
cl_report = {s: classification_report(y_test[:, i], p) for i, (s, p) in enumerate(sorted(pred.items()))}

Here is the performance:

x1
              precision    recall  f1-score   support

           0       0.76      0.52      0.62       988
           1       0.25      0.49      0.33       323

    accuracy                           0.51      1311
   macro avg       0.50      0.51      0.48      1311
weighted avg       0.63      0.51      0.55      1311


x2
              precision    recall  f1-score   support

           0       0.87      0.56      0.68      1086
           1       0.22      0.58      0.31       225

    accuracy                           0.57      1311
   macro avg       0.54      0.57      0.50      1311
weighted avg       0.75      0.57      0.62      1311


x3
              precision    recall  f1-score   support

           0       0.79      0.50      0.61       998
           1       0.26      0.57      0.36       313

    accuracy                           0.52      1311
   macro avg       0.53      0.54      0.49      1311
weighted avg       0.66      0.52      0.55      1311


x4
              precision    recall  f1-score   support

           0       0.84      0.54      0.65      1061
           1       0.22      0.57      0.32       250

    accuracy                           0.54      1311
   macro avg       0.53      0.55      0.49      1311
weighted avg       0.72      0.54      0.59      1311


x5
              precision    recall  f1-score   support

           0       0.72      0.53      0.61       926
           1       0.31      0.50      0.38       385

    accuracy                           0.52      1311
   macro avg       0.51      0.52      0.50      1311
weighted avg       0.60      0.52      0.54      1311


x6
              precision    recall  f1-score   support

           0       0.85      0.57      0.69      1077
           1       0.22      0.54      0.31       234

    accuracy                           0.57      1311
   macro avg       0.53      0.56      0.50      1311
weighted avg       0.74      0.57      0.62      1311


x7
              precision    recall  f1-score   support

           0       0.81      0.55      0.65      1021
           1       0.25      0.53      0.34       290

    accuracy                           0.55      1311
   macro avg       0.53      0.54      0.50      1311
weighted avg       0.68      0.55      0.59      1311


x8
              precision    recall  f1-score   support

           0       0.84      0.55      0.66      1069
           1       0.21      0.53      0.30       242

    accuracy                           0.55      1311
   macro avg       0.52      0.54      0.48      1311
weighted avg       0.72      0.55      0.60      1311


x9
              precision    recall  f1-score   support

           0       0.79      0.54      0.64      1006
           1       0.26      0.52      0.35       305

    accuracy                           0.54      1311
   macro avg       0.52      0.53      0.49      1311
weighted avg       0.67      0.54      0.57      1311

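Comments:

Wow, that makes a lot of sense. Train each column as a single label instead of training the whole multi-label problem at once. Sweet~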
