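sklearn svm provides a bad fit
【Posted】: 2021-03-17 13:32:26
【Question】: I was trying to plot an SVM for my sample data, but I ran into a problem: the plot does not look right at all. That is odd, because I used the sample code from here (more specifically, the "So what's happening?" section). Their code works fine for me, so I assume the problem has something to do with my data. I noticed that the fitted coefficients are very small, which understandably wrecks the line.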
Here is the reproducible code.
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
import matplotlib as mpl
plt.figure(figsize=(5,5))
in_cir = lambda x,y: True if x**2 + y**2 <= 4 else False # Checking if point is in the purple circle
f = lambda x,e: 1.16*x + 0.1 + e # True function
ran = np.arange(-5,6)
lsp = np.linspace(-5,5,170) # X1 axis
np.random.seed(69)
dots = f(lsp,[np.random.normal(0,1.5) for i in lsp]) # X2 axis
blue_dots, pur_dots, lsp1, lsp2 = [], [], [], []
for i, x in zip(dots, lsp):
    if in_cir(x,i): pur_dots.append(i); lsp2.append(x) # Getting all purple dots' X1 and X2
    else: blue_dots.append(i); lsp1.append(x) # Same for blue ones
plt.scatter(lsp1, blue_dots, color='cornflowerblue')
plt.scatter(lsp2, pur_dots, color='magenta')
plt.xlabel('$X_1$', fontsize=15)
plt.ylabel('$X_2$', fontsize=15)
x, y = np.array(list(zip(lsp, dots))), np.where( np.array([in_cir(x,i) for x,i in zip(lsp,dots)]) == True, 'p','b' )
# On two lines above I made x a 2d array
# of coordinates for each dot
# And y is a list of 'b' if the corresponding
# dot is blue and 'p' otherwise
ft = svm.SVC(kernel='linear', C=1).fit(x, y) # Fitting svc
# Here starts the code from the link
w = ft.coef_[0]
print('w', w) # w components are really small
a = -w[0] / w[1]
xx = np.linspace(-5, 5)
yy = a * xx - (ft.intercept_[0]) / w[1] # This is where it all goes wrong
b = ft.support_vectors_[0]
yy_down = a * xx + (b[1] - a * b[0])
b = ft.support_vectors_[-1]
yy_up = a * xx + (b[1] - a * b[0])
plt.plot(xx, yy, 'k-')
plt.plot(xx, yy_down, 'k--')
plt.plot(xx, yy_up, 'k--')
plt.ylim(-5, 5.5) # To make it interpretable
plt.xlim(-5, 4.5) # the plot will be squished because of
plt.show() # high values if removed
The output is:
As you can see, the result is miserable. I would be grateful if someone could explain what I am doing wrong.
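For reference, the plotting code above relies on the linear decision function f(x) = w[0]*x1 + w[1]*x2 + b: the boundary is the line f(x) = 0, the margin lines are f(x) = -1 and f(x) = +1, and the margin is 2 / ||w|| wide, so very small coefficients mean a very wide margin. The sketch below only illustrates that arithmetic; the helper names `level_line` and `margin_width` are hypothetical and not part of scikit-learn, and the margin lines here are taken from the levels -1 and +1 rather than from the first and last support vectors.

```python
import numpy as np

def level_line(w, b, x1, level=0.0):
    """Return x2 such that w[0]*x1 + w[1]*x2 + b == level (needs w[1] != 0)."""
    w = np.asarray(w, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    return (level - b - w[0] * x1) / w[1]

def margin_width(w):
    """Distance between the lines f(x) = -1 and f(x) = +1."""
    return 2.0 / np.linalg.norm(w)

# Hypothetical coefficients, chosen only to make the arithmetic easy to follow.
w, b = [1.0, 1.0], 0.0
x1 = np.array([0.0, 2.0])
print(level_line(w, b, x1))             # decision boundary, f(x) = 0
print(level_line(w, b, x1, level=1.0))  # one margin line, f(x) = +1
print(margin_width(w))                  # margin width for these coefficients
print(margin_width([1e-3, 1e-3]))       # tiny coefficients give a huge margin
```

With a fitted model, `w` and `b` would be `ft.coef_[0]` and `ft.intercept_[0]`.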
Edit: I actually managed to do it. Here is the code I wrote:
plt.figure(figsize=(7,7))
np.random.seed(420)
ran = np.arange(-5,6)
st = 1
b, p = np.array([ (-3+np.random.normal(0,st), -2.5+np.random.normal(0,st)) for i in range(25) ]+\
[ (2.5+np.random.normal(0,st), 3.5+np.random.normal(0,st)) for i in range(25) ]), np.array([ (np.random.normal(0,st), np.random.normal(0,st)) for i in range(50) ])
plt.scatter(b[:,0], b[:,1], color='cornflowerblue')
plt.scatter(p[:,0], p[:,1], color='magenta')
plt.xlabel('$X_1$', fontsize=15)
plt.ylabel('$X_2$', fontsize=15)
x, y = np.concatenate( (np.concatenate( (b[:25], p) ), b[-25:]) ), [0]*25 + [1]*50 + [0]*25
ft = svm.SVC(kernel='linear').fit(x, y)
by, bx = np.meshgrid([-5, 6], [-5, 6])
bo = ft.decision_function(np.vstack([by.ravel(), bx.ravel()]).T).reshape(bx.shape).T
xx, yy = np.meshgrid(np.arange(-5.1, 4.6, 0.01),
np.arange(-5.1, 5.6, 0.01))
Z = ft.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
C = plt.contourf(xx, yy, Z,colors='none', hatches=['.'])
colors=['cornflowerblue', 'magenta']
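# Note: ContourSet.collections is deprecated as of Matplotlib 3.8, so this loop assumes an older version
for j, collection in enumerate(C.collections):
    if j == 0: collection.set_edgecolor(colors[0])
    else: collection.set_edgecolor(colors[1])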
plt.contour(bx, by, bo, colors='0', levels=[-1, 0, 1], linestyles=['--', '-', '--'])
plt.ylim(-5, 5.5)
plt.xlim(-5, 4.5)
plt.show()
And its result is:
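The contour levels -1, 0 and 1 used above work because, for a binary SVC with a linear kernel, `decision_function` is the affine function w . x + b built from `coef_` and `intercept_`. A minimal sketch that checks this numerically, assuming two made-up Gaussian blobs rather than the data from the post:

```python
import numpy as np
from sklearn import svm

rng = np.random.default_rng(0)
# Two hypothetical Gaussian blobs (an assumption, not the data from the post).
X = np.vstack([rng.normal(-2.0, 1.0, size=(50, 2)),
               rng.normal(2.0, 1.0, size=(50, 2))])
y = np.array([0] * 50 + [1] * 50)

clf = svm.SVC(kernel='linear').fit(X, y)

# Rebuild the decision values by hand from the fitted coefficients.
manual = X @ clf.coef_[0] + clf.intercept_[0]
library = clf.decision_function(X)
print(np.allclose(manual, library, atol=1e-6))
```

Level 0 is therefore the decision boundary, and levels -1 and +1 are the two margin lines.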
【Comments on the question】:
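【Answer 1】: You are trying to use a linear classifier to separate data which is not linearly separable (i.e. you cannot draw a straight line that separates the two groups). You could use another kernel instead, for example RBF: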
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
from mlxtend.plotting import plot_decision_regions
plt.figure(figsize=(5,5))
in_cir = lambda x,y: True if x**2 + y**2 <= 4 else False # Checking if point is in the purple circle
f = lambda x,e: 1.16*x + 0.1 + e # True function
ran = np.arange(-5,6)
lsp = np.linspace(-5,5,170) # X1 axis
np.random.seed(69)
dots = f(lsp,[np.random.normal(0,1.5) for i in lsp]) # X2 axis
blue_dots, pur_dots, lsp1, lsp2 = [], [], [], []
for i, x in zip(dots, lsp):
    if in_cir(x,i): pur_dots.append(i); lsp2.append(x) # Getting all purple dots' X1 and X2
    else: blue_dots.append(i); lsp1.append(x) # Same for blue ones
x, y = np.array(list(zip(lsp, dots))), np.where(np.array([in_cir(x,i) for x,i in zip(lsp,dots)]), 'p','b')
y[y == 'b'] = 0 # replacing letters with integers as the plot_decision_regions function accepts only integers
y[y == 'p'] = 1
y = y.astype(int)
ft = svm.SVC(kernel='rbf', C=1).fit(x, y) # Fitting svc
plot_decision_regions(X=x,
y=y,
clf=ft,
legend=2)
plt.show()
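The effect of the kernel choice can also be checked without plotting, by comparing training accuracy. The following is a rough sketch on stand-in data: `make_circles` produces one class as a small ring inside the other, which is an assumption standing in for the "purple circle" in the question and not the question's actual data set.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import make_circles

# Stand-in data (an assumption): the inner ring is one class, the outer ring
# the other, so no straight line can separate the two groups.
X, y = make_circles(n_samples=200, factor=0.3, noise=0.05, random_state=0)

linear_clf = svm.SVC(kernel='linear', C=1).fit(X, y)
rbf_clf = svm.SVC(kernel='rbf', C=1).fit(X, y)

# Fraction of training points each model classifies correctly.
linear_acc = linear_clf.score(X, y)
rbf_acc = rbf_clf.score(X, y)
print('linear kernel training accuracy:', linear_acc)
print('RBF kernel training accuracy:', rbf_acc)
```

On data like this the linear kernel cannot do much better than guessing, while the RBF kernel can wrap a closed boundary around the inner class.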
【Comments】:
Thanks, that's true. However, I still managed to do it. I'll update the post in case anyone is interested.