如何使用 dlib 的 LDA

Posted

技术标签:

【中文标题】如何使用 dlib 的 LDA【英文标题】:How to use dlib's LDA 【发布时间】:2018-04-05 05:44:41 【问题描述】:

我想在我的训练集上拟合 dlib 的 LDA,并将转换应用于训练集和测试集。我编写了以下最小示例来重现该问题。如果删除使用 LDA 的部分,它应该会输出有意义的预测。

#include <iostream>
#include <vector>
#include <dlib/svm.h>

int main() 

    typedef dlib::matrix<float, 2, 1> sample_type;
    typedef dlib::radial_basis_kernel<sample_type> kernel_type;
    dlib::svm_c_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(0.5f));
    trainer.set_c(1.0f);

    std::vector<sample_type> samples_train;
    std::vector<float> labels_train;
    std::vector<sample_type> samples_test;
    std::vector<float> labels_test;

    sample_type sample;
    float label;

    label = -1;
    sample(0) = -1;
    sample(1) = -1;
    samples_train.push_back(sample);
    labels_train.push_back(label);

    label = 1;
    sample(0) = 1;
    sample(1) = 1;
    samples_train.push_back(sample);
    labels_train.push_back(label);

    label = 1;
    sample(0) = 0.5;
    sample(1) = 0.5;
    samples_test.push_back(sample);
    labels_test.push_back(label);

    // Fit LDA on training data
    dlib::matrix<sample_type> X;
    dlib::matrix<sample_type,0,1> mean;
    dlib::compute_lda_transform(X, mean, labels_train);

    // Apply LDA on train data
    for (auto &sample_train : samples_train)
        sample_train = X * sample_train;
    

    // Apply LDA on test data
    for (auto &sample_test : samples_test)
        sample_test = X * sample_test;
    

    auto predictor = trainer.train(samples_train, labels_train);

    std::cout << "Train Sample 1: " << predictor(samples_train[0]) << ", label: " << labels_train[0] << std::endl;
    std::cout << "Train Sample 2: " << predictor(samples_train[1]) << ", label: " << labels_train[1] << std::endl;
    std::cout << "Test Sample: " << predictor(samples_test[0]) << ", label: " << labels_test[0] << std::endl;


错误:

cannot convert 'labels_train' (type 'std::__debug::vector<float>') to type 'const std::__debug::vector<long unsigned int>&'

但如果标签与样本的类型不同,SVM 就会抛出错误。我在 dlib 的 github 存储库中找不到任何示例。

【问题讨论】:

【参考方案1】:

您应该使用两组标签,一组是用于 lda 的 long unsigned 类型,另一组是用于您的 SVM 的 float 类型

【讨论】:

以上是关于如何使用 dlib 的 LDA的主要内容,如果未能解决你的问题,请参考以下文章

如何解释 LDA 组件(使用 sklearn)?

如何检查 dlib 是不是正在使用 GPU?

如何使用 Python 和 DLib 找到线的左侧或右侧的点?

如何在 Xcode C++ 控制台应用程序中使用 dlib

如何使用 conda 在 spyder 上导入 dlib?

如何使用 Dlib 的多目标检测器?