使用神经网络和 ROCR 包绘制神经网络曲线
Posted
技术标签:
【中文标题】使用神经网络和 ROCR 包绘制神经网络曲线【英文标题】:Plot a Neural Net Curve Using neuralnet and ROCR package 【发布时间】:2017-08-27 15:12:44 【问题描述】:这里我有一个分类任务,我需要使用神经网络和 ROCR 包。问题是我在使用预测功能时收到错误消息。
这是我的代码:
#load packages
require(neuralnet)
library(ROCR)
#create data set
train<-read.table(file="train.txt",header=TRUE,sep=",")
test<- read.table(file="test.txt",header=TRUE,sep=",")
#build model and make predictions
nn.sag <- neuralnet(Type ~ Area+Perimeter+Compactness+Length+Width+Asymmetry+Groove, data = train, hidden = 5, algorithm = "sag", err.fct = "sse", linear.output = FALSE)
prob = compute(nn.sag, test[, -ncol(test)] )
prob.result <- prob$net.result
nn.pred = prediction(prob.result, test$Type)
pref <- performance(nn.pred, "tpr", "fpr")
plot(pref)
在这里,我收到了“预测”功能的错误消息: '$ 运算符对原子向量无效'
数据集看起来像(这里只有训练数据集):
Area,Perimeter,Compactness,Length,Width,Asymmetry,Groove,Type
14.8,14.52,0.8823,5.656,3.288,3.112,5.309,1
14.79,14.52,0.8819,5.545,3.291,2.704,5.111,1
14.99,14.56,0.8883,5.57,3.377,2.958,5.175,1
19.14,16.61,0.8722,6.259,3.737,6.682,6.053,0
15.69,14.75,0.9058,5.527,3.514,1.599,5.046,1
14.11,14.26,0.8722,5.52,3.168,2.688,5.219,1
13.16,13.55,0.9009,5.138,3.201,2.461,4.783,1
16.16,15.33,0.8644,5.845,3.395,4.266,5.795,0
15.01,14.76,0.8657,5.789,3.245,1.791,5.001,1
14.11,14.1,0.8911,5.42,3.302,2.7,5,1
17.98,15.85,0.8993,5.979,3.687,2.257,5.919,0
21.18,17.21,0.8989,6.573,4.033,5.78,6.231,0
14.29,14.09,0.905,5.291,3.337,2.699,4.825,1
14.59,14.28,0.8993,5.351,3.333,4.185,4.781,1
11.42,12.86,0.8683,5.008,2.85,2.7,4.607,1
12.11,13.47,0.8392,5.159,3.032,1.502,4.519,1
15.6,15.11,0.858,5.832,3.286,2.725,5.752,0
15.38,14.66,0.899,5.477,3.465,3.6,5.439,0
18.94,16.49,0.875,6.445,3.639,5.064,6.362,0
12.36,13.19,0.8923,5.076,3.042,3.22,4.605,1
14.01,14.29,0.8625,5.609,3.158,2.217,5.132,1
17.12,15.55,0.8892,5.85,3.566,2.858,5.746,0
15.78,14.91,0.8923,5.674,3.434,5.593,5.136,1
16.19,15.16,0.8849,5.833,3.421,0.903,5.307,1
14.43,14.4,0.8751,5.585,3.272,3.975,5.144,1
13.8,14.04,0.8794,5.376,3.155,1.56,4.961,1
14.46,14.35,0.8818,5.388,3.377,2.802,5.044,1
18.59,16.05,0.9066,6.037,3.86,6.001,5.877,0
18.75,16.18,0.8999,6.111,3.869,4.188,5.992,0
15.49,14.94,0.8724,5.757,3.371,3.412,5.228,1
12.73,13.75,0.8458,5.412,2.882,3.533,5.067,1
13.5,13.85,0.8852,5.351,3.158,2.249,5.176,1
14.38,14.21,0.8951,5.386,3.312,2.462,4.956,1
14.86,14.67,0.8676,5.678,3.258,2.129,5.351,1
18.45,16.12,0.8921,6.107,3.769,2.235,5.794,0
17.32,15.91,0.8599,6.064,3.403,3.824,5.922,0
20.2,16.89,0.8894,6.285,3.864,5.173,6.187,0
20.03,16.9,0.8811,6.493,3.857,3.063,6.32,0
18.14,16.12,0.8772,6.059,3.563,3.619,6.011,0
13.99,13.83,0.9183,5.119,3.383,5.234,4.781,1
15.57,15.15,0.8527,5.92,3.231,2.64,5.879,0
16.2,15.27,0.8734,5.826,3.464,2.823,5.527,1
20.97,17.25,0.8859,6.563,3.991,4.677,6.316,0
14.16,14.4,0.8584,5.658,3.129,3.072,5.176,1
13.45,14.02,0.8604,5.516,3.065,3.531,5.097,1
15.5,14.86,0.882,5.877,3.396,4.711,5.528,1
16.77,15.62,0.8638,5.927,3.438,4.92,5.795,0
12.74,13.67,0.8564,5.395,2.956,2.504,4.869,1
14.88,14.57,0.8811,5.554,3.333,1.018,4.956,1
14.28,14.17,0.8944,5.397,3.298,6.685,5.001,1
14.34,14.37,0.8726,5.63,3.19,1.313,5.15,1
14.03,14.16,0.8796,5.438,3.201,1.717,5.001,1
19.11,16.26,0.9081,6.154,3.93,2.936,6.079,0
14.52,14.6,0.8557,5.741,3.113,1.481,5.487,1
18.43,15.97,0.9077,5.98,3.771,2.984,5.905,0
18.81,16.29,0.8906,6.272,3.693,3.237,6.053,0
13.78,14.06,0.8759,5.479,3.156,3.136,4.872,1
14.69,14.49,0.8799,5.563,3.259,3.586,5.219,1
18.85,16.17,0.9056,6.152,3.806,2.843,6.2,0
12.88,13.5,0.8879,5.139,3.119,2.352,4.607,1
12.78,13.57,0.8716,5.262,3.026,1.176,4.782,1
14.33,14.28,0.8831,5.504,3.199,3.328,5.224,1
19.46,16.5,0.8985,6.113,3.892,4.308,6.009,0
19.38,16.72,0.8716,6.303,3.791,3.678,5.965,0
15.26,14.85,0.8696,5.714,3.242,4.543,5.314,1
20.24,16.91,0.8897,6.315,3.962,5.901,6.188,0
19.94,16.92,0.8752,6.675,3.763,3.252,6.55,0
20.71,17.23,0.8763,6.579,3.814,4.451,6.451,0
16.17,15.38,0.8588,5.762,3.387,4.286,5.703,0
13.02,13.76,0.8641,5.395,3.026,3.373,4.825,1
16.53,15.34,0.8823,5.875,3.467,5.532,5.88,0
13.89,14.02,0.888,5.439,3.199,3.986,4.738,1
18.98,16.57,0.8687,6.449,3.552,2.144,6.453,0
17.08,15.38,0.9079,5.832,3.683,2.956,5.484,1
15.03,14.77,0.8658,5.702,3.212,1.933,5.439,1
16.14,14.99,0.9034,5.658,3.562,1.355,5.175,1
18.65,16.41,0.8698,6.285,3.594,4.391,6.102,0
20.1,16.99,0.8746,6.581,3.785,1.955,6.449,0
17.99,15.86,0.8992,5.89,3.694,2.068,5.837,0
15.88,14.9,0.8988,5.618,3.507,0.7651,5.091,1
13.22,13.84,0.868,5.395,3.07,4.157,5.088,1
18.3,15.89,0.9108,5.979,3.755,2.837,5.962,0
19.51,16.71,0.878,6.366,3.801,2.962,6.185,0
【问题讨论】:
【参考方案1】:prediction()
函数在 R 中的神经网络和 ROCR 包中都可用。所以不要同时加载这两个包。首先加载神经网络,训练您的模型,然后使用detach()
将其分离,然后加载 ROCR 包。试试下面的代码:
#load packages
require(neuralnet)
#create data set
train<-read.table(file="train.txt",header=TRUE,sep=",")
test<- read.table(file="test.txt",header=TRUE,sep=",")
#build model and make predictions
nn.sag <- neuralnet(Type ~ Area+Perimeter+Compactness+Length+Width+Asymmetry+Groove, data = train, hidden = 5, algorithm = "sag", err.fct = "sse", linear.output = FALSE)
prob = compute(nn.sag, test[, -ncol(test)] )
prob.result <- prob$net.result
detach(package:neuralnet,unload = T)
library(ROCR)
nn.pred = prediction(prob.result, test$Type)
pref <- performance(nn.pred, "tpr", "fpr")
plot(pref)
【讨论】:
【参考方案2】:或者干脆使用ROCR::prediction(prediction(prob.result, test$Type))
用于选择正确的包。
【讨论】:
以上是关于使用神经网络和 ROCR 包绘制神经网络曲线的主要内容,如果未能解决你的问题,请参考以下文章