c_cpp dlib_LearnSVM
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了c_cpp dlib_LearnSVM相关的知识,希望对你有一定的参考价值。
#include "pch.h"
#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
#include <dlib/time_this.h>
#include <dlib/algs.h>
#include <algorithm>
#include <cerrno>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace dlib;
using namespace std;
// A single feature vector: column vector of doubles whose length is chosen at runtime.
typedef matrix<double, 0, 1> sample_type;
// Multiclass trainer built from pairwise (one-vs-one) binary trainers held behind
// dlib's type-erased any_trainer; class labels are stored as doubles.
typedef one_vs_one_trainer<any_trainer<sample_type>, double> ovo_trainer;
// Gaussian RBF kernel operating on sample_type vectors.
typedef radial_basis_kernel<sample_type> rbf_kernel_type;
/// Common base for the learners in this file.
///
/// train(), save() and load() are virtual no-ops here; subclasses override the
/// ones they support. The destructor is virtual so a learner may be deleted
/// through a Learn pointer.
class Learn {
public:
    Learn() = default;
    virtual ~Learn() = default;

    /// Fit the model. Does nothing in the base class.
    virtual void train() {}

    /// Persist the model to `pathname`. Does nothing in the base class.
    virtual void save(std::string pathname) {}

    /// Restore the model from `pathname`. Does nothing in the base class.
    virtual void load(std::string pathname) {}

    /// Copy a std::vector<double> into a sample_type column vector.
    inline sample_type vectorToSample(std::vector<double> sample_);
};
/// Base for supervised learners: accumulates training pairs for a subclass's
/// train() to consume.
class LearnSupervised : public Learn {
public:
    LearnSupervised() = default;

    /// Store one (sample, label) pair; the std::vector overload converts the
    /// features to sample_type first.
    void addSample(std::vector<double> sample, double label);
    void addSample(sample_type sample, double label);

    /// Predict a label for `sample`. The base versions always return 0.
    virtual double predict(std::vector<double> &sample) { return 0.0; }
    virtual double predict(sample_type &sample) { return 0.0; }

    /// Remove every stored sample and label.
    void clearTrainingInstances();

protected:
    // Parallel containers: samples[i] carries the label labels[i].
    std::vector<sample_type> samples;
    std::vector<double> labels;
};
class LearnSVM : public LearnSupervised {
public:
LearnSVM();
~LearnSVM();
void train();
void trainWithGridParameterSearch();
double predict(std::vector<double> & sample);
double predict(sample_type & sample);
void save(string path);
void load(string path);
private:
ovo_trainer trainer;
krr_trainer<rbf_kernel_type> rbf_trainer;
one_vs_one_decision_function<ovo_trainer> df;
};
// END_OF_CLASS_DEFINITION
//////////////////////////////////////////////////////////////////////
//// Learn ////
//////////////////////////////////////////////////////////////////////
/// Copy a std::vector<double> into a freshly sized dlib column vector.
///
/// @param sample_ Feature values in order (taken by value to keep the declared
///                signature unchanged).
/// @return A sample_type with sample_.size() rows holding the same values.
inline sample_type Learn::vectorToSample(std::vector<double> sample_) {
    // dlib matrices are sized and indexed with `long`; using a matching signed
    // index avoids mixing `int` with `size_t` (and `int` overflow on huge inputs).
    const long length = static_cast<long>(sample_.size());
    sample_type sample(length);
    for (long i = 0; i < length; ++i) {
        // i < size() by construction, so unchecked access is safe.
        sample(i) = sample_[static_cast<std::size_t>(i)];
    }
    return sample;
}
//////////////////////////////////////////////////////////////////////
//// Supervised ////
//////////////////////////////////////////////////////////////////////
/// Append one training pair.
///
/// @param sample Feature vector; taken by value and moved into storage.
/// @param label  Target value. A value outside [0.0, 1.0] only prints a warning
///               to std::cerr; the pair is stored regardless.
void LearnSupervised::addSample(sample_type sample, double label) {
    if (label < 0.0 || label > 1.0) {
        std::cerr << "label should be between 0.0 and 1.0" << std::endl;
    }
    // `sample` is our own copy (by-value sink), so move it instead of copying again.
    samples.push_back(std::move(sample));
    labels.push_back(label);
}
void LearnSupervised::addSample(std::vector<double> sample, double label) {
if (label < 0.0 || label > 1.0) {
std::cerr << "label should be between 0.0 and 1.0" << std::endl;
}
sample_type tmp(sample.size());
for (int i = 0; i < sample.size(); i++) {
tmp(i) = sample.at(i);
}
samples.push_back(tmp);
labels.push_back(label);
}
/// Discard every stored training pair. Uses std::vector::clear, so the
/// containers keep their allocated capacity.
void LearnSupervised::clearTrainingInstances()
{
    labels.clear();
    samples.clear();
}
/// Default-constructs the trainers and an untrained decision function; no
/// training data is held until addSample() is called.
LearnSVM::LearnSVM() = default;
/// Nothing to release manually: all members clean up after themselves.
LearnSVM::~LearnSVM() = default;
void LearnSVM::train() {
rbf_trainer.set_kernel(rbf_kernel_type(0.1));
rbf_trainer.set_lambda(0.01);
trainer.set_trainer(rbf_trainer);
randomize_samples(samples, labels);
df = trainer.train(samples, labels);
}
/// Intended (judging by its name) to search over training parameters before
/// fitting. Not implemented: it neither trains nor touches `df`.
void LearnSVM::trainWithGridParameterSearch()
{
    // TODO: not implemented yet.
}
/// Classify one sample with the one-vs-one decision function.
///
/// `df` is only assigned by train() and load(); call one of them first.
/// @return The predicted class label.
double LearnSVM::predict(sample_type & sample)
{
    const double predicted_label = df(sample);
    return predicted_label;
}
double LearnSVM::predict(std::vector<double> & sample) {
return df(vectorToSample(sample));
}
/// Serialize the trained decision function to `path` (dlib binary format).
///
/// `df` holds its pairwise classifiers type-erased. Following dlib's multiclass
/// example, it is first converted to a concrete type whose classifiers are
/// decision_function<rbf_kernel_type>, which is what the krr_trainer used in
/// train() produces.
///
/// @throws dlib::serialization_error if the file cannot be opened, or if writing
///         (including the final flush on close) fails.
void LearnSVM::save(string path) {
    std::ofstream fout(path, std::ios::binary);
    if (!fout) {
        throw dlib::serialization_error("LearnSVM::save: unable to open '" + path + "' for writing");
    }

    one_vs_one_decision_function<ovo_trainer, decision_function<rbf_kernel_type>> typed_df;
    typed_df = df;
    serialize(typed_df, fout);

    // An implicit close in the destructor would silently drop a failed final
    // flush, so close explicitly and check the stream state.
    fout.close();
    if (!fout) {
        throw dlib::serialization_error("LearnSVM::save: failed while writing '" + path + "'");
    }
}
/// Restore the decision function previously written by save().
///
/// The file is read into the concrete (RBF-typed) decision-function form and
/// then assigned to `df`. `df` is modified only after deserialization succeeds.
///
/// @throws dlib::serialization_error if the file cannot be opened or its
///         contents cannot be deserialized.
void LearnSVM::load(string path) {
    std::ifstream fin(path, std::ios::binary);
    if (!fin) {
        throw dlib::serialization_error("LearnSVM::load: unable to open '" + path + "' for reading");
    }

    one_vs_one_decision_function<ovo_trainer, decision_function<rbf_kernel_type>> typed_df;
    deserialize(typed_df, fin);
    df = typed_df;
}
/// Linearly map `inValue` from [minInRange, maxInRange] onto
/// [minOutRange, maxOutRange].
///
/// Inputs outside the input range are extrapolated, not clamped; combine with
/// clamp() when a bounded result is needed. If maxInRange == minInRange the
/// division yields an infinity or NaN.
double interpolate(double inValue, double minInRange, double maxInRange, double minOutRange, double maxOutRange) {
    // Position of inValue within the input range: 0 at minInRange, 1 at maxInRange.
    const double t = (inValue - minInRange) / (maxInRange - minInRange);
    return minOutRange + (maxOutRange - minOutRange) * t;
}
/// Limit `value` to [lower, upper].
///
/// The upper cap is applied first and the lower bound last, so `lower` wins
/// when the bounds are inverted (lower > upper).
double clamp(double value, double lower, double upper) {
    const double capped = (upper < value) ? upper : value;
    return (lower < capped) ? capped : lower;
}
/// Minimal driver: builds a classifier and times a single train() call.
///
/// NOTE(review): the sample-generation loop below is commented out, so as
/// written train() runs with no training data. Confirm this is intentional.
int main(void) {
    LearnSVM classifier;
    dlib::rand rng;  // consumed only by the commented-out sample generator below
    //std::cout << clamp(interpolate(2, 0, 10, 0, 1), 0, 1) << std::endl;
    //for (int i = 0; i < 300; i++)
    //{
    // double x = rng.get_double_in_range(0, 3000);
    // double y = 0.00074 * (x * x) + 0.0095 * x + rng.get_double_in_range(-80, 80);
    // x = clamp(interpolate(x, 0, 3000, 0, 1), 0, 1);
    // y = clamp(interpolate(y, 0, 3000, 0, 1), 0, 1);
    // std::cout << x << "\t" << y << std::endl;
    // std::vector<double> sample;
    // sample.push_back(x);
    // classifier.addSample(sample, y);
    //}
    TIME_THIS(classifier.train());
    return 0;
}
以上是关于c_cpp dlib_LearnSVM的主要内容,如果未能解决你的问题,请参考以下文章