c_cpp dlib_LearnSVM

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了c_cpp dlib_LearnSVM相关的知识,希望对你有一定的参考价值。

#include "pch.h"

#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
#include <dlib/time_this.h>
#include <dlib/algs.h>


using namespace dlib;
using namespace std;

// Feature vector: a dlib column vector whose length is chosen at run time.
using sample_type = matrix<double, 0, 1>;

// Multiclass trainer built from binary learners (one per pair of labels);
// class labels are doubles.
using ovo_trainer = one_vs_one_trainer<any_trainer<sample_type>, double>;
// Radial basis function kernel over sample_type, used by the binary learners.
using rbf_kernel_type = radial_basis_kernel<sample_type>;

// Abstract-ish base for all learners: every hook has a do-nothing default
// so concrete learners override only what they support.
class Learn {
public:
	Learn() = default;
	virtual ~Learn() = default;

	// Fit the model to whatever data the subclass has collected (no-op here).
	virtual void train() {}
	// Persist the model to `pathname` (no-op here).
	virtual void save(std::string pathname) {}
	// Restore a model from `pathname` (no-op here).
	virtual void load(std::string pathname) {}

	// Copy a std::vector<double> into a dlib column vector; defined out of class.
	inline sample_type vectorToSample(std::vector<double> sample_);
};

// Learner that is trained from labelled examples. Examples are accumulated
// with addSample() and kept until clearTrainingInstances() is called.
class LearnSupervised : public Learn {
public:
	LearnSupervised() = default;

	// Store one (features, label) pair; overloads accept either container type.
	void addSample(std::vector<double> sample, double label);
	void addSample(sample_type sample, double label);

	// Predict a label for `sample`; the base implementation always yields 0.
	virtual double predict(std::vector<double> &sample) { return 0; }
	virtual double predict(sample_type &sample) { return 0; }

	// Forget every stored training example.
	void clearTrainingInstances();

protected:
	std::vector<sample_type> samples; // feature vectors, parallel to `labels`
	std::vector<double> labels;       // label for samples[i] is labels[i]
};

class LearnSVM : public LearnSupervised {
public:
	LearnSVM();
	~LearnSVM();
	void train();
	void trainWithGridParameterSearch();
	double predict(std::vector<double> & sample);
	double predict(sample_type & sample);
	void save(string path);
	void load(string path);
private:
	ovo_trainer trainer;
	krr_trainer<rbf_kernel_type> rbf_trainer;
	one_vs_one_decision_function<ovo_trainer> df;
};
// END_OF_CLASS_DEFINITION

//////////////////////////////////////////////////////////////////////
////                            Learn                             ////
//////////////////////////////////////////////////////////////////////
// Copy a std::vector<double> into a dlib column vector of the same length;
// element i of the input becomes sample(i).
inline sample_type Learn::vectorToSample(std::vector<double> sample_) {
	// dlib matrices are indexed with `long`; convert the size once instead of
	// comparing a signed index against an unsigned size every iteration.
	const long length = static_cast<long>(sample_.size());
	sample_type sample(length);
	for (long i = 0; i < length; ++i) {
		sample(i) = sample_[static_cast<std::size_t>(i)];
	}
	return sample;
}

//////////////////////////////////////////////////////////////////////
////                         Supervised                           ////
//////////////////////////////////////////////////////////////////////
// Store one training example (feature column vector + label).
// A label outside [0, 1] only triggers a warning; the example is still kept.
// A sample whose dimension differs from the already-stored samples is
// rejected with a warning, because kernel evaluations during training
// combine samples element-wise and require a common dimension.
void LearnSupervised::addSample(sample_type sample, double label) {
	if (label < 0.0 || label > 1.0) {
		std::cerr << "label should be between 0.0 and 1.0" << std::endl;
	}
	if (!samples.empty() && sample.size() != samples.front().size()) {
		std::cerr << "sample dimension " << sample.size()
		          << " does not match previously added samples ("
		          << samples.front().size() << "); sample ignored" << std::endl;
		return;
	}
	// `sample` is a by-value sink: move it into storage instead of copying.
	samples.push_back(std::move(sample));
	labels.push_back(label);
}

void LearnSupervised::addSample(std::vector<double> sample, double label) {
	if (label < 0.0 || label > 1.0) {
		std::cerr << "label should be between 0.0 and 1.0" << std::endl;
	}
	sample_type tmp(sample.size());
	for (int i = 0; i < sample.size(); i++) {
		tmp(i) = sample.at(i);
	}
	samples.push_back(tmp);
	labels.push_back(label);
}

// Discard every stored training example. Both containers are emptied so the
// samples/labels pairing stays consistent.
void LearnSupervised::clearTrainingInstances() {
	labels.erase(labels.begin(), labels.end());
	samples.erase(samples.begin(), samples.end());
}

// Nothing to set up eagerly: trainers are configured inside train().
LearnSVM::LearnSVM() = default;

// All members clean up after themselves.
LearnSVM::~LearnSVM() = default;

void LearnSVM::train() {
	rbf_trainer.set_kernel(rbf_kernel_type(0.1));
	rbf_trainer.set_lambda(0.01);
	trainer.set_trainer(rbf_trainer);

	randomize_samples(samples, labels);
	df = trainer.train(samples, labels);
}


// Placeholder for hyperparameter-searching training. It is not implemented:
// no model is trained. It warns instead of returning silently, so callers do
// not mistake the call for a successful training run.
// TODO: implement the search (e.g. over the RBF gamma / KRR lambda used in train()).
void LearnSVM::trainWithGridParameterSearch() {
	std::cerr << "trainWithGridParameterSearch() is not implemented yet; no model was trained"
	          << std::endl;
}

// Classify a dlib column vector with the current model and return the
// predicted label.
double LearnSVM::predict(sample_type & sample) {
	const double predictedLabel = df(sample);
	return predictedLabel;
}

double LearnSVM::predict(std::vector<double> & sample) {
	return df(vectorToSample(sample));
}

// Serialize the current model to `path` in binary form.
// The member `df` stores its binary decision functions type-erased; it is
// first converted to a one_vs_one_decision_function with the concrete
// decision_function<rbf_kernel_type> type, which is what gets serialized
// (and what load() expects to read back).
// Throws dlib::serialization_error if the file cannot be opened or written.
void LearnSVM::save(string path) {
	typedef one_vs_one_decision_function<ovo_trainer, decision_function<rbf_kernel_type>> rbf_ovo_df;

	ofstream fout(path.c_str(), ios::binary);
	if (!fout) {
		throw serialization_error("Unable to open " + path + " for writing.");
	}

	rbf_ovo_df df2;
	df2 = df;
	serialize(df2, fout);

	// Surface buffered write failures here instead of losing them on close.
	fout.flush();
	if (!fout) {
		throw serialization_error("Error while writing model to " + path);
	}
}

// Load a model previously written by save() from `path`.
// The file holds a one_vs_one_decision_function over
// decision_function<rbf_kernel_type>; after a successful read it replaces
// the current model. On failure `df` is left unchanged.
// Throws dlib::serialization_error if the file cannot be opened or parsed.
void LearnSVM::load(string path) {
	typedef one_vs_one_decision_function<ovo_trainer, decision_function<rbf_kernel_type>> rbf_ovo_df;

	ifstream fin(path.c_str(), ios::binary);
	if (!fin) {
		throw serialization_error("Unable to open " + path + " for reading.");
	}

	rbf_ovo_df df2;
	deserialize(df2, fin);
	df = df2;
}

// Linearly map `inValue` from the range [minInRange, maxInRange] onto
// [minOutRange, maxOutRange] (default output range: [0, 1]).
// Values outside the input range are extrapolated; wrap with clamp() to
// saturate. If the input range has zero width there is no meaningful
// mapping, so minOutRange is returned instead of dividing by zero.
double interpolate(double inValue, double minInRange, double maxInRange,
                   double minOutRange = 0.0, double maxOutRange = 1.0) {
	const double inSpan = maxInRange - minInRange;
	if (inSpan == 0.0) {
		return minOutRange;
	}
	// Position of inValue relative to the input range (0 at min, 1 at max).
	const double t = (inValue - minInRange) / inSpan;
	return minOutRange + (maxOutRange - minOutRange) * t;
}

// Restrict `value` to [lower, upper]: it is capped at `upper` first, then
// raised to at least `lower` (so `lower` wins if lower > upper).
double clamp(double value, double lower, double upper) {
	const double capped = (upper < value) ? upper : value;
	return (lower < capped) ? capped : lower;
}

// Demo driver: builds a classifier and times its training. The synthetic
// data generator below is commented out, so training currently runs on an
// empty data set. Returns 1 if training throws.
int main(void) {
	LearnSVM classifier;
	dlib::rand rng; // only used by the commented-out data generator below

	//std::cout << clamp(interpolate(2, 0, 10), 0, 1) << std::endl;

	//for (int i = 0; i < 300; i++)
	//{
	//	double x = rng.get_double_in_range(0, 3000);
	//	double y = 0.00074 * (x * x) + 0.0095 * x + rng.get_double_in_range(-80, 80);
	//	x = clamp(interpolate(x, 0, 3000, 0, 1), 0, 1);
	//	y = clamp(interpolate(y, 0, 3000, 0, 1), 0, 1);

	//	std::cout << x << "\t" << y << std::endl;
	//	std::vector<double> sample;
	//	sample.push_back(x);

	//	classifier.addSample(sample, y);


	//}

	try {
		TIME_THIS(classifier.train());
	} catch (const std::exception& e) {
		// Report training failures instead of letting them escape main.
		std::cerr << "training failed: " << e.what() << std::endl;
		return 1;
	}
	return 0;
}

以上是关于c_cpp dlib_LearnSVM的主要内容,如果未能解决你的问题,请参考以下文章

c_cpp 200.岛屿数量

c_cpp 127.单词阶梯

c_cpp MOFSET

c_cpp MOFSET

c_cpp 31.下一个排列

c_cpp string→char *