使用 LIBSVM 的 one vs rest 多类分类。 matlab
Posted
技术标签:
【中文标题】使用 LIBSVM 的 one vs rest 多类分类。 matlab【英文标题】:One vs rest multiclass classification using LIBSVM. matlab 【发布时间】:2014-06-14 23:11:24 【问题描述】:我正在尝试使用 LIBSVM 实现 one vs rest 多类分类。
这个链接很有用http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/ovr_multiclass/ 但我在函数 'ovrpredict()' 中遇到错误。
功能如下:
function [pred, ac, decv] = ovrpredict(y, x, model)
labelSet = model.labelSet;
labelSetSize = length(labelSet);
models = model.models;
decv= zeros(size(y, 1), labelSetSize);
for i=1:labelSetSize
[l,a,d] = svmpredict(double(y == labelSet(i)), x, modelsi);
decv(:, i) = d * (2 * modelsi.Label(1) - 1); % ERROR IN THIS LINE
end
[tmp,pred] = max(decv, [], 2);
pred = labelSet(pred);
ac = sum(y==pred) / size(x, 1);
我得到的错误信息是Reference to non-existent field 'Label'
。
任何建议都会很有帮助。
编辑 1
用于调用函数的代码:\
[trainY trainX]=libsvmread('libfacecombine.train');
[testY testX]=libsvmread('libfacetest.train');
model=ovrtrain(trainY,trainX,'-c 8 -g 4');
[~,accuracy,~]=ovrpredict(testY,testX,model);
训练和测试数据即“libfacecombine.train”和“libfacetest.train”是从.csv文件中获取的:
f1=createdabase(f); % where createdatabase is a function to read various images from a folder and arrange into 1D array
[sig1 mn1]=pcam(f1); % where pcam is a function to find 'pca'(sig1) and 'mean'(mn1) of the data
%labelling is done this way:
%Positive class
label=[];
for i=1:length(sig1)
for j=1:1
label(i,j)=+1;
end
end
csvwrite('face1.csv',[label sig1]);
%Negative class
label1=[];
for i=1:length(sig2) % sig2 obtained in same way as sig1
for j=1:1
label1(i,j)=-1;
end
end
csvwrite('face2.csv',[label sig2]);
使用“附加”模式将这两个文件合并并转换为 .train 文件。 测试数据也是如此。
编辑 2
我有 5 节课。并且标记为: 第 1 类:+1 包含来自 Face 1 的 4 张图像的特征,-1 包含来自非 Face 1 的 4 张图像(人脸 2、3、4 和 5)的特征。 2类:+2包含来自Face 2的4张图像的特征,-2包含来自非Face 2(面1、3、4和5)的4张图像的特征...... 5类:+5包含来自Face 2的4张图像的特征面 5 和 -5 包含来自非面 5(面 1、2、3 和 4)的 4 张图像的特征。所有这些特征连同标签都按上述顺序写入 .csv 文件,然后转换为 .train 格式。因此我获得了训练文件。
对于测试图像,我拍摄一张人脸 1 的图像并给出其真实标签,即 +1 并写入 .csv 文件,然后转换为 .train。因此我获得了测试文件。当我运行程序时,我会得到如下结果:
Accuracy=92%(12/13)classification;
Accuracy=61%(8/13)classification;
Accuracy=100%(13/13)classification;
Accuracy=100%(13/13)classification;
Accuracy=100%(13/13)classification;
Accuracy=100%(13/13)classification;
当我只有 5 个类时,为什么我会获得 6 个准确度值?
【问题讨论】:
你是如何调用函数的?你需要展示你的代码和你训练模型的部分。如果您也描述您的数据也会有所帮助。顺便说一句,我已经在之前的答案中展示了如何使用 libsvm 执行一对多分类,请参阅 here 和 here。 你能确认你得到的错误信息是Undefined variable Label
吗?当您不小心引用了一个不存在的变量时,这不是您会得到的确切文本,这让我认为您可能记错了错误。此外,我不确定您如何从该行中获取该错误。也许是Reference to non-existent field 'Label'
?
@Amro:我的数据集是人脸图像。我在做人脸识别。我已经看过你以前的答案,但都使用了fisheriris,而且我标记数据的方式与这个数据集不同。我将返回调用函数的代码。
@SamRoberts:是的,错误消息是“引用不存在的字段‘标签’”。对不起,我记错了。我该如何克服它?
@Sid:我不明白,根据您上面的代码,您有两个类(正面和负面实例)。那么,当您只有二元分类时,为什么需要执行 one-vs-rest 呢?此外,您仍然没有提供有关数据的足够信息:您有多少特征?类标签的类型是什么?你如何从 CSV 文件到 libsvm 稀疏表示?如果可能,您应该发布一小部分数据样本和可重现的示例代码..
【参考方案1】:
虽然我不确定如何解决这个问题,但以下是我的观察:
在ovrtrain
中创建了一个模型,其中唯一的添加字段为labelSet
。
稍后在overpredict
代码尝试读出label
。
所以,我看到了三个选项:
-
模型创建错误
中间的一步被遗忘了
模型使用错误
编辑:这部分可能不是解决方案,正如评论中提到的那样
我没有足够的信息知道哪些适用,但如果 代码中没有提供示例,您可以尝试替换
Label
labelSet
无处不在。
请注意,如果是第 1 点或第 3 点,此代码应始终提供错误并且永远不会正常运行。
【讨论】:
我将尝试在任何地方用 LabelSet 替换 Label,然后将结果回复给您。谢谢! -1 不正确;model
是 svmtrain
作为创建 SVM 模型的结果返回的结构(多次调用,全部分组在另一个结构中)。现在由于训练是以一对一的方式进行的(1-vs-not1、2-vs-not2、..),每次Label
字段将只包含两个类标签[0 1]
(值1对应于类i
中的实例,其余为0)。另一方面,labelSet
是从整个标签集构建的,并包含所有可能的类值 ([1 2 3 ..]
)。就目前而言,这个问题没有提供足够的信息来充分回答..
我确信@Amro 是正确的,但由于我不知道如何使答案更好,我决定在我的建议旁边放置一个警告,并保持更一般的部分可以帮助您从结构上解决问题。
@DennisJaheruddin:(1)我刚刚给出了 2 类 SVM 的示例,我想将其扩展到 n 类 SVM,这就是为什么要尝试使用上述功能。(2)和我已经提取了 PCA 特征(采用了给出最大方差的组件)。(3)类标签的形式为模型 1(+1,-1),模型 2(+2,-2)...... on. 在模型 1 中,正类 (+1) 将具有面部 1 训练图像,负类 (-1) 将具有所有其他面部图像(面部 1 图像除外),
在模型 2 中,正类 (+2) 将具有人脸 2 训练图像,负类 (-2) 将具有所有其他人脸图像(除了人脸 2 图像)等等。然后需要将测试图像与所有这些模型进行比较,与测试图像匹配精度最高的模型将是测试人脸所属的模型(属于该模型的正类)。【参考方案2】:
在互联网上找到解决此类问题的方法很棘手,但让我们试一试。这篇文章由问题而不是答案组成。但是,我相信,如果您全部回答了这些问题,您会在没有进一步帮助的情况下找到您的错误——或者至少找到了 90% 的错误。
所有这些步骤都适用于调试任何类型的 MATLAB 程序。
确保工作空间整洁
工作区中的旧版本变量可能难以调试。变量名中的拼写错误可能会导致意外使用旧版本。在调试开始时使用 clear
清除工作区。
运行一个更简单的程序
我已经编译了 libsvm,添加了 ovr_multiclass
插件,我可以成功运行以下我编写的示例脚本:
clear
% random train and test data
trainX = rand(10, 4);
trainY = randi(4, 10, 1);
testX = rand(10, 4);
testY = randi(4, 10, 1);
model=ovrtrain(trainY,trainX,'-c 8 -g 4');
[~,accuracy,~]=ovrpredict(testY,testX,model);
你能运行这个,还是你得到和以前一样的错误? 这种最小工作示例对于调试非常有用。 使用用户生成的少量数据可确保错误不会来自意外来源,并有助于缩小原因范围。
检查models
元胞数组
您声明此错误正在发生:
decv(:, i) = d * (2 * modelsi.Label(1) - 1); % ERROR IN THIS LINE
该行的关键部分是modelsi.Label(1)
。这是采用单元格数组models
,并提取i
th 项。这个i
th 项应该是一个结构,具有一个名为Label
的字段。 Label
应该是一个非空数组,可以从中提取第一个元素。 models
元胞数组是model
结构中的一个字段,它作为第三个参数传入ovrpredict
。
运行上面我非常简单的测试脚本后,我在 MATLAB 命令窗口中运行以下诊断程序:
>> models = model.models
models =
[1x1 struct]
[1x1 struct]
[1x1 struct]
[1x1 struct]
>> models1
ans =
Parameters: [5x1 double]
nr_class: 2
totalSV: 6
rho: -1.2122
Label: [2x1 double]
sv_indices: [6x1 double]
ProbA: []
ProbB: []
nSV: [2x1 double]
sv_coef: [6x1 double]
SVs: [6x4 double]
>> models1.Label
ans =
0
1
如果你做同样的事情,你会得到同样的结果吗?如果没有,请将您的输出发布到这些命令的编辑中。
函数中的调试
如果模型看起来不错,但您仍然收到错误,请通过在命令窗口中键入以下内容来打开错误调试器:
dbstop if error
当 MATLAB 在函数中遇到错误时,它现在会暂停并允许您检查所有变量。
再次运行您的程序(或者我的程序,如果您在我发布的最小工作示例中遇到错误)。发生错误时程序应暂停。您的命令提示符应该从 >>
更改为 K>>
。
执行上述步骤以示例model
单元格数组。然后尝试复制、粘贴和运行在命令窗口中出现错误的代码行。尝试运行它的一小部分,例如modelsi.Label(1)
然后2 * modelsi.Label(1) - 1
。
输入dbquit
退出调试器,输入dbclear if error
关闭错误自动调试。
(另请参阅下面关于错误消息的问题 - 确保错误确实发生在您认为的位置!)
编辑:一些额外的问题
您使用的是哪个版本的 MATLAB?例如R2013a
如果您在 MATLAB 命令行中键入 which ovrpredict
,您是否看到了所需文件的路径? (即 ovrpredict.m 保存在计算机上的正确路径)
您的ovrpredict.m
文件(如which ovrpredict
所指)是否包含您在问题中粘贴的内容?您得到的错误表明它们可能是一个微小的差异,例如一个额外的空间。
【讨论】:
非常感谢。通过调试该功能,我发现出了什么问题。我之前在尝试更正错误时更改了变量名称,并且模型、模型、标签集、标签的某处都混淆了。再次感谢您的函数调试技术,它真的很有帮助! 该代码非常适用于小型数据集,但对于我正在使用的 caltech 数据库,它会不停地进行训练(持续了长达两个小时!显然出了点问题)。有什么建议吗? 加州理工学院的数据库有多大?如果你对加州理工学院的数据库进行二次抽样,使其大小达到 100 分之一,会发生什么?如果数据量大且模型复杂,训练模型可能需要很长时间(有时超过一天!)。这个新问题可能超出了这个问题的范围,可能需要一个自己的问题...... 比尔,是的,因为数据库的大小,我把它分成更小的集合,可以获得结果。现在的问题是,我真的不明白结果。 我有 5 节课。并且标记为: 1 类:+1 包含来自 4 张人脸 1 图像的特征,-1 包含来自 4 张非人脸 1 图像(人脸 2、3、4 和 5)的特征。 2类:+2包含来自Face 2的4张图像的特征,-2包含来自非Face 2(面1、3、4和5)的4张图像的特征...... 5类:+5包含来自Face 2的4张图像的特征面 5 和 -5 包含来自非面 5(面 1、2、3 和 4)的 4 张图像的特征。所有这些特征连同标签都按上述顺序写入 .csv 文件,然后转换为 .train 格式。因此我获得了训练文件。以上是关于使用 LIBSVM 的 one vs rest 多类分类。 matlab的主要内容,如果未能解决你的问题,请参考以下文章
多分类学习方法One vs. RestOne vs. OneMany vs. Many多输出分类
用于多类分类的 SVM(one-vs-all)中的置信度估计