R:如何在 ggplot2 中绘制 svm 的超平面和边距?

Posted

技术标签:

【中文标题】R:如何在 ggplot2 中绘制 svm 的超平面和边距?【英文标题】:R: How to plot the hyperplane and margins of an svm in ggplot2? 【发布时间】:2016-06-19 19:15:58 【问题描述】:

我正在关注 Tibshirani 的 ISL 文本。我正在尝试在 ggplot2 中绘制 SVM 的结果。我可以获得点和支持向量,但我不知道如何获得为 2D 案例绘制的边距和超平面。我用谷歌搜索并检查了 e1071 自述文件。一个通用的、动态的解决方案(适用于各种 SVM 内核、成本等)会很棒。这是我的 MWE:

set.seed(1)
N=20
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1;x[y==1,]
dat = data.frame(x=x, y=as.factor(y))
library(e1071)
library(ggplot2)
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)

df = dat; df
df = cbind(df, sv=rep(0,nrow(df)))
df[svmfit$index,]$sv = 1

ggplot(data=df,aes(x=x.1,y=x.2,group=y,color=y)) +     
    geom_point(aes(shape=factor(sv)))

类似这样的: (来自 Python 的 scikit-learn)

【问题讨论】:

e1071 中已经为svm 定义了一个基本的图形绘制方法。你看plot(svmfit, dat)的结果了吗?你想在 ggplot 中复制它吗? 是的,我正在尝试在 ggplot 中复制它,并为超平面添加线,为 2D 案例 (K=2) 的边距添加虚线。 您可能想接受@user21359 的回答,因为它就像一个魅力 【参考方案1】:

所以你不想绘制支持向量对吗?这是基于plot.svm 源代码的非常基本的内容,适用于您的示例。

https://github.com/cran/e1071/blob/master/R/svm.R

您可以通过查看该源代码来构建更丰富的东西。

library(e1071)
library(ggplot2)
set.seed(1)
N=20
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1;x[y==1,]
dat = data.frame(x=x, y=as.factor(y))
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)

grid <- expand.grid(seq(min(dat[, 1]), max(dat[, 1]),length.out=100),                                                                                                         
                            seq(min(dat[, 2]), max(dat[, 2]),length.out=100)) 
names(grid) <- names(dat)[1:2]
preds <- predict(svmfit, grid)
df <- data.frame(grid, preds)
ggplot(df, aes(x = x.2, y = x.1, fill = preds)) + geom_tile()

应该输出这个:

将此与plot.svm 输出进行比较:

plot(svmfit, dat)

编辑:

如果你也想重现这些点,我对上面的代码做了一些改动:

cols <- c('1' = 'red', '-1' = 'black')
tiles <- c('1' = 'magenta', '-1' = 'cyan')
shapes <- c('support' = 4, 'notsupport' = 1)
dat$support <- 'notsupport'
dat[svmfit$index, 'support'] <- 'support'

ggplot(df, aes(x = x.2, y = x.1)) + geom_tile(aes(fill = preds)) + 
  scale_fill_manual(values = tiles) +
  geom_point(data = dat, aes(color = y, shape = support), size = 2) +
  scale_color_manual(values = cols) +
  scale_shape_manual(values = shapes) +
  ggtitle('SVM classification plot')

【讨论】:

以上是关于R:如何在 ggplot2 中绘制 svm 的超平面和边距?的主要内容,如果未能解决你的问题,请参考以下文章

9幅图快速理解支持向量机(SVM)的工作原理

九张图快速理解支持向量机(SVM)的工作原理

paper 123: SVM如何避免过拟合

如何在 R 中绘制一类 SVM?

如何在 R 中绘制 SVM 的分类图

如何在R中的ggplot2中绘制组均值的平均值?