Sklearn Linear SVM cannot train in multilabel classification
Posted: 2021-05-19 00:52:46
I want to train a linear SVM for multilabel classification with the following code:
from sklearn.svm import LinearSVC
from sklearn.multioutput import MultiOutputClassifier
import numpy as np
data = np.loadtxt('tictac_multi.txt')
X = data[:,:9]
y = data[:,9:]
clf = MultiOutputClassifier(LinearSVC(random_state=0, tol=1e-5, C=100, penalty='l2',max_iter=2000))
clf.fit(X, y)
print(clf.score(X, y))
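The dataset can be found here: https://www.connellybarnes.com/work/class/2016/deep_learning_graphics/proj1/tictac_multi.txt
I've tried tuning different parameters, such as C, tol and max_iter, but the linear SVM model still doesn't train well. Whatever parameters I adjust, the training accuracy stays below 0.01...
The output of the code above is the following ConvergenceWarning (printed 9 times in total), followed by the score: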
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
0.011601282246985194
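With the current code, the accuracy is only 0.0116.
Comments on the question:
Hi! The link you provided doesn't work. Could you provide an excerpt of the input data and describe your classification problem in more detail (what the inputs are, the classes, etc.)?
@PieCot Thanks for reaching out. It's a multilabel binary problem. Each input is a vector of 9 values, each of which is 0, 1 or -1. Each output is also a vector of 9 values, each of which is 0 or 1. For example, input vector = [0 1 1 0 0 -1 0 0 0], output vector = [1 0 0 1 1 0 1 1 1].
@PieCot I've updated the link. Could you try again? Thanks a lot.
I don't think your data is linearly separable. In that case, a linear SVM won't help.
Answer 1:
It looks like a "tic-tac-toe" dataset (judging from the file name and the format).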
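For context on that number: MultiOutputClassifier.score reports subset (exact-match) accuracy, so a sample only counts as correct when all 9 labels are predicted correctly, which is far stricter than per-cell accuracy. Note also that MultiOutputClassifier already fits one LinearSVC per output column (its estimators_ list), so part of the gap between 0.0116 and per-cell scores plausibly comes from the metric itself. The minimal sketch below compares the two quantities; it uses random synthetic data as a hypothetical stand-in for tictac_multi.txt, so the values it prints say nothing about the real dataset.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import LinearSVC

# Hypothetical stand-in for tictac_multi.txt: 9 board cells in {-1, 0, 1}, 9 binary targets
rng = np.random.default_rng(0)
X = rng.integers(-1, 2, size=(500, 9)).astype(float)
y = (rng.random((500, 9)) < 0.25).astype(int)

clf = MultiOutputClassifier(LinearSVC(dual=False, class_weight='balanced'))
clf.fit(X, y)
y_pred = clf.predict(X)

# One fitted LinearSVC per output column
print(len(clf.estimators_))

# score() is exact-match accuracy: all 9 labels of a sample must be right
subset_acc = clf.score(X, y)
# accuracy_score on a multilabel indicator matrix computes the same quantity
subset_acc_metric = accuracy_score(y, y_pred)
# Per-cell accuracy: fraction of samples where each individual label is right
per_label_acc = (y_pred == y).mean(axis=0)

print(f"subset accuracy:    {subset_acc:.4f}")
print(f"per-label accuracy: {np.round(per_label_acc, 3)}")
```

Because an exact match requires every label to be correct, subset accuracy can never exceed the lowest per-label accuracy.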
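Assuming that the first 9 columns of the dataset describe the 9 cells of the board at a given moment of the game, and the other 9 columns mark the cells that correspond to good moves, you can train a classifier cell by cell, to predict whether each cell is a good move.
So you actually need to train 9 binary classifiers, not one. Based on this idea, I sketched a very simple approach in the code below. After splitting the dataset into training and test sets (80/20), start with a simple cross-validation: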
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate, train_test_split
from sklearn.metrics import classification_report
import pandas as pd
# Load data, creating a DataFrame holding inputs and outputs
df = pd.read_csv('tictac_multi.txt', sep=' ', header=None)[list(range(18))].copy()
df.columns = pd.MultiIndex.from_product((('input', 'output'), [f'x{i}' for i in range(1, 10)]))
# Split dataset 80/20 (also shuffles it)
X_train, X_test, y_train, y_test = train_test_split(df['input'].values, df['output'].values, test_size=0.2, random_state=42)
# Get scores from cross-validation: one LinearSVC per output cell 'x1'..'x9'
scores = {
    s: cross_validate(
        LinearSVC(random_state=0, dual=False, class_weight='balanced', tol=1e-5),
        X_train, y_train[:, i], cv=5, scoring=['balanced_accuracy', 'precision', 'recall', 'f1_weighted'],
        n_jobs=-1,
    ) for i, s in enumerate(sorted(df['output'].columns))
}
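As you can see, I used some non-default options for the classifier (dual=False, class_weight='balanced'): they are just educated guesses. You should investigate further to better understand the data and the problem, and then search for the best parameters for the model (e.g., with a grid search).
Here are the cross-validation scores (the scores dictionary):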
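As one possible starting point for that search, here is a minimal sketch that uses GridSearchCV to tune C and class_weight for a single cell's LinearSVC. The grid values are arbitrary assumptions, and the synthetic X and y are a hypothetical stand-in for X_train and one column y_train[:, i]; with the real data you would repeat the search for each of the 9 cells.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

# Hypothetical stand-in for X_train and one output column y_train[:, i]
rng = np.random.default_rng(42)
X = rng.integers(-1, 2, size=(400, 9)).astype(float)
y = (X[:, 0] + X[:, 4] + rng.normal(0, 0.5, 400) > 0.5).astype(int)

# Assumed grid: arbitrary values, widen or refine as needed
param_grid = {
    'C': [0.01, 0.1, 1, 10],
    'class_weight': [None, 'balanced'],
}
search = GridSearchCV(
    LinearSVC(dual=False, tol=1e-5),
    param_grid,
    scoring='balanced_accuracy',  # one of the metrics used in the cross-validation above
    cv=5,
)
search.fit(X, y)
print(search.best_params_)
print(round(search.best_score_, 3))
```

Scoring on balanced accuracy rather than plain accuracy keeps the search from favouring a classifier that simply predicts the majority class "not a good move".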
{'x1': {'fit_time': array([0.01000571, 0.00814652, 0.00937247, 0.00622296, 0.00536656]),
        'score_time': array([0.01159358, 0.00597596, 0.00835085, 0.00647163, 0.00619125]),
        'test_balanced_accuracy': array([0.52209841, 0.51820565, 0.53743952, 0.55455645, 0.53620968]),
        'test_precision': array([0.25454545, 0.25 , 0.26611227, 0.27659574, 0.26295585]),
        'test_recall': array([0.5060241 , 0.52016129, 0.51612903, 0.5766129 , 0.55241935]),
        'test_f1_weighted': array([0.56543736, 0.55328701, 0.58232694, 0.57711117, 0.56292617])},
 'x2': {'fit_time': array([0.00737047, 0.00885296, 0.00616217, 0.00707698, 0.0071764 ]),
        'score_time': array([0.00650406, 0.00595641, 0.00623679, 0.00636506, 0.00567913]),
        'test_balanced_accuracy': array([0.57367382, 0.5342687 , 0.55287658, 0.56565243, 0.57909451]),
        'test_precision': array([0.22520661, 0.20041754, 0.21073559, 0.22037422, 0.23175966]),
        'test_recall': array([0.5828877 , 0.51336898, 0.56684492, 0.56684492, 0.57446809]),
        'test_f1_weighted': array([0.6183652 , 0.60068273, 0.59707974, 0.61584554, 0.63060231])},
 'x3': {'fit_time': array([0.0067966 , 0.00759745, 0.00617337, 0.00679278, 0.00650382]),
        'score_time': array([0.00605631, 0.00537109, 0.00551271, 0.00665474, 0.00649571]),
        'test_balanced_accuracy': array([0.52683332, 0.54103562, 0.56227539, 0.53312408, 0.51986383]),
        'test_precision': array([0.25502008, 0.26639344, 0.28367347, 0.26035503, 0.25 ]),
        'test_recall': array([0.51626016, 0.52845528, 0.56275304, 0.53441296, 0.53036437]),
        'test_f1_weighted': array([0.56805171, 0.58208858, 0.59506983, 0.56776364, 0.55079222])},
 'x4': {'fit_time': array([0.00649667, 0.00767159, 0.00802064, 0.00769711, 0.00611663]),
        'score_time': array([0.00572419, 0.00529647, 0.00616765, 0.00592041, 0.00609517]),
        'test_balanced_accuracy': array([0.53369766, 0.57259312, 0.57644138, 0.55746825, 0.51877354]),
        'test_precision': array([0.19791667, 0.22290389, 0.22540984, 0.21489362, 0.18930041]),
        'test_recall': array([0.51351351, 0.58602151, 0.59139785, 0.54301075, 0.49462366]),
        'test_f1_weighted': array([0.6005693 , 0.615313 , 0.61784599, 0.61784823, 0.58924053])},
 'x5': {'fit_time': array([0.00650501, 0.005898 , 0.00682783, 0.00629449, 0.00635648]),
        'score_time': array([0.00553894, 0.0059135 , 0.00625896, 0.00583744, 0.00580502]),
        'test_balanced_accuracy': array([0.51108635, 0.50499149, 0.52183641, 0.53230958, 0.51296946]),
        'test_precision': array([0.30185185, 0.29735234, 0.31163708, 0.322 , 0.30522088]),
        'test_recall': array([0.53094463, 0.47557003, 0.51465798, 0.52272727, 0.49350649]),
        'test_f1_weighted': array([0.5248707 , 0.53861778, 0.54612005, 0.55679291, 0.54217533])},
 'x6': {'fit_time': array([0.00703621, 0.00908065, 0.00665092, 0.00619102, 0.00814819]),
        'score_time': array([0.00568652, 0.00626183, 0.00584817, 0.00574327, 0.00552726]),
        'test_balanced_accuracy': array([0.55457928, 0.55569106, 0.50701258, 0.53690769, 0.56919396]),
        'test_precision': array([0.2145749 , 0.21621622, 0.18480493, 0.20416667, 0.22540984]),
        'test_recall': array([0.56084656, 0.55026455, 0.47619048, 0.51851852, 0.57894737]),
        'test_f1_weighted': array([0.60241544, 0.61008882, 0.5813744 , 0.60080544, 0.6130977 ])},
 'x7': {'fit_time': array([0.0070405 , 0.00908256, 0.00702643, 0.00635576, 0.00632381]),
        'score_time': array([0.00546646, 0.00674367, 0.00542998, 0.00671315, 0.00549483]),
        'test_balanced_accuracy': array([0.53124816, 0.52187224, 0.54180051, 0.57438252, 0.52764072]),
        'test_precision': array([0.27054108, 0.26235741, 0.27659574, 0.30364372, 0.26824034]),
        'test_recall': array([0.52325581, 0.53488372, 0.55642023, 0.58365759, 0.48638132]),
        'test_f1_weighted': array([0.56745684, 0.54860915, 0.56677092, 0.5996452 , 0.57954721])},
 'x8': {'fit_time': array([0.00761437, 0.00997519, 0.006984 , 0.00623441, 0.00683069]),
        'score_time': array([0.00540686, 0.00635052, 0.00645804, 0.00535131, 0.00548935]),
        'test_balanced_accuracy': array([0.51471322, 0.56996108, 0.52712724, 0.5443143 , 0.55319282]),
        'test_precision': array([0.18661258, 0.22292994, 0.192607 , 0.20408163, 0.20874751]),
        'test_recall': array([0.49462366, 0.56451613, 0.53513514, 0.54054054, 0.56756757]),
        'test_f1_weighted': array([0.58328382, 0.62374708, 0.57815794, 0.60051373, 0.59779516])},
 'x9': {'fit_time': array([0.00723267, 0.0069263 , 0.00828266, 0.00672913, 0.00750995]),
        'score_time': array([0.00545311, 0.00556946, 0.00732398, 0.0056181 , 0.00555682]),
        'test_balanced_accuracy': array([0.53490307, 0.55281703, 0.58447809, 0.52272419, 0.54294236]),
        'test_precision': array([0.26388889, 0.27868852, 0.29811321, 0.25506073, 0.27198364]),
        'test_recall': array([0.53413655, 0.54618474, 0.63453815, 0.5060241 , 0.532 ]),
        'test_f1_weighted': array([0.56987212, 0.58922553, 0.59075641, 0.56631422, 0.5819019 ])}}
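As you can see, they are not great, but they are far from 0.
Now train the models on the entire training set and evaluate their performance on the test data: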
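The nested scores dictionary is hard to scan by eye. A small helper (hypothetical, not part of the answer) can collapse each cell's cross_validate result into the mean of every test_* metric, giving one row per cell. The sketch below builds a synthetic scores dictionary of the same shape on random data so that it runs on its own; with the real data you would call summarize(scores) on the dictionary computed earlier.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_validate
from sklearn.svm import LinearSVC


def summarize(scores):
    """Collapse {'x1': cross_validate result, ...} into a table of mean test scores.

    Rows are cells ('x1'..'x9'); columns are metric names without the 'test_' prefix.
    """
    rows = {
        cell: {key[len('test_'):]: values.mean()
               for key, values in result.items() if key.startswith('test_')}
        for cell, result in scores.items()
    }
    return pd.DataFrame(rows).T


# Hypothetical synthetic data with the same shape as the tic-tac-toe inputs/outputs
rng = np.random.default_rng(0)
X = rng.integers(-1, 2, size=(300, 9)).astype(float)
Y = (rng.random((300, 9)) < 0.3).astype(int)

scores = {
    f'x{i + 1}': cross_validate(
        LinearSVC(dual=False, class_weight='balanced'),
        X, Y[:, i], cv=5, scoring=['balanced_accuracy', 'f1_weighted'],
    )
    for i in range(9)
}

table = summarize(scores)
print(table.round(3))
```

Averaging over the 5 folds hides the fold-to-fold spread; adding values.std() next to the mean would be a simple extension if that variation matters.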
def train_clfs(clfs, X, y):
    # Fit classifier s on output column i (keys sorted, so 'x1' -> column 0, ..., 'x9' -> column 8)
    return {s: clf.fit(X, y[:, i]) for i, (s, clf) in enumerate(sorted(clfs.items()))}
def get_predictions(clfs, inp):
    # Predict, for every cell, whether it is a good move
    return {s: clf.predict(inp) for s, clf in clfs.items()}
# Train the classifiers
clfs = {s: LinearSVC(random_state=0, dual=False, class_weight='balanced', tol=1e-5) for s in sorted(df['output'].columns)}
clfs = train_clfs(clfs, X_train, y_train)
# Try them on the test values
pred = get_predictions(clfs, X_test)
# Get the classification report for each classifier
cl_report = {s: classification_report(y_test[:, i], p) for i, (s, p) in enumerate(sorted(pred.items()))}
for s, report in cl_report.items():
    print(s)
    print(report)
Here is the performance:
x1
              precision    recall  f1-score   support
           0       0.76      0.52      0.62       988
           1       0.25      0.49      0.33       323
    accuracy                           0.51      1311
   macro avg       0.50      0.51      0.48      1311
weighted avg       0.63      0.51      0.55      1311
x2
              precision    recall  f1-score   support
           0       0.87      0.56      0.68      1086
           1       0.22      0.58      0.31       225
    accuracy                           0.57      1311
   macro avg       0.54      0.57      0.50      1311
weighted avg       0.75      0.57      0.62      1311
x3
              precision    recall  f1-score   support
           0       0.79      0.50      0.61       998
           1       0.26      0.57      0.36       313
    accuracy                           0.52      1311
   macro avg       0.53      0.54      0.49      1311
weighted avg       0.66      0.52      0.55      1311
x4
              precision    recall  f1-score   support
           0       0.84      0.54      0.65      1061
           1       0.22      0.57      0.32       250
    accuracy                           0.54      1311
   macro avg       0.53      0.55      0.49      1311
weighted avg       0.72      0.54      0.59      1311
x5
              precision    recall  f1-score   support
           0       0.72      0.53      0.61       926
           1       0.31      0.50      0.38       385
    accuracy                           0.52      1311
   macro avg       0.51      0.52      0.50      1311
weighted avg       0.60      0.52      0.54      1311
x6
              precision    recall  f1-score   support
           0       0.85      0.57      0.69      1077
           1       0.22      0.54      0.31       234
    accuracy                           0.57      1311
   macro avg       0.53      0.56      0.50      1311
weighted avg       0.74      0.57      0.62      1311
x7
              precision    recall  f1-score   support
           0       0.81      0.55      0.65      1021
           1       0.25      0.53      0.34       290
    accuracy                           0.55      1311
   macro avg       0.53      0.54      0.50      1311
weighted avg       0.68      0.55      0.59      1311
x8
              precision    recall  f1-score   support
           0       0.84      0.55      0.66      1069
           1       0.21      0.53      0.30       242
    accuracy                           0.55      1311
   macro avg       0.52      0.54      0.48      1311
weighted avg       0.72      0.55      0.60      1311
x9
              precision    recall  f1-score   support
           0       0.79      0.54      0.64      1006
           1       0.26      0.52      0.35       305
    accuracy                           0.54      1311
   macro avg       0.52      0.53      0.49      1311
weighted avg       0.67      0.54      0.57      1311
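To compare the per-cell classifiers with the single number reported in the question, one option is to stack the 9 prediction arrays back into an (n_samples, 9) matrix and compute multilabel metrics on it. The sketch below uses a hypothetical helper, stack_predictions, and tiny hand-made arrays (4 samples, 3 cells) standing in for pred and y_test. accuracy_score on such a matrix gives subset (exact-match) accuracy, the same kind of number as the 0.0116 in the question, while hamming_loss is the fraction of individual cell labels that are wrong. The helper relies on the keys sorting into column order, which holds for 'x1' to 'x9' but would break for names like 'x10'.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, hamming_loss


def stack_predictions(pred):
    """Stack {'x1': array, 'x2': array, ...} into an (n_samples, n_cells) matrix.

    Columns follow sorted key order, which matches column order for 'x1'..'x9'.
    """
    return np.column_stack([pred[s] for s in sorted(pred)])


# Hypothetical toy data: 4 samples, 3 cells
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
pred = {
    'x1': np.array([1, 0, 1, 0]),
    'x2': np.array([0, 1, 0, 0]),
    'x3': np.array([1, 0, 0, 1]),
}
Y_pred = stack_predictions(pred)

print(accuracy_score(y_true, Y_pred))             # 0.75: 3 of the 4 rows are fully correct
print(hamming_loss(y_true, Y_pred))               # 1 wrong label out of 12
print(f1_score(y_true, Y_pred, average='macro'))  # mean of the per-cell F1 scores
```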
Comments:
Wow, that makes a lot of sense: train each column as a separate single-label problem instead of training the whole multilabel output at once. Sweet!