简单HOG+SVM mnist手写数字分类
Posted 学渣
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了简单HOG+SVM mnist手写数字分类相关的知识,希望对你有一定的参考价值。
使用工具 :VS2013 + OpenCV 3.1
数据集:minst
训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml
数据准备
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
首先我们利用matlab将数据转换成 .bmp 图片格式
fid_image=fopen(\'train-images.idx3-ubyte\',\'r\'); fid_label=fopen(\'train-labels.idx1-ubyte\',\'r\'); % Read the first 16 Bytes magicnumber=fread(fid_image,4); size=fread(fid_image,4); row=fread(fid_image,4); col=fread(fid_image,4); % Read the first 8 Bytes extra=fread(fid_label,8); % Read labels related to images imageIndex=fread(fid_label); Num=length(imageIndex); % Count repeat times of 0 to 9 cnt=zeros(1,10); for k=1:Num image=fread(fid_image,[max(row),max(col)]); % Get image data val=imageIndex(k); % Get value of image for i=0:9 if val==i cnt(val+1)=cnt(val+1)+1; end end if cnt(val+1)<10 str=[num2str(val),\'_000\',num2str(cnt(val+1)),\'.bmp\']; elseif cnt(val+1)<100 str=[num2str(val),\'_00\',num2str(cnt(val+1)),\'.bmp\']; elseif cnt(val+1)<1000 str=[num2str(val),\'_0\',num2str(cnt(val+1)),\'.bmp\']; else str=[num2str(val),\'_\',num2str(cnt(val+1)),\'.bmp\']; end imwrite(image\',str); end fclose(fid_image); fclose(fid_label);
然后使用cmd指令写入图片路径: dir /b/s/p/w *.bmp > num.txt 添加标签,如下图
然后打乱样本顺序。
训练
int main0() { vector<string> img_path;//输入文件名变量 vector<int> img_catg; int nLine = 0; string line; size_t pos; ifstream svm_data("./train-images/random.txt");//训练样本图片的路径都写在这个txt文件中,使用bat批处理文件可以得到这个txt文件 unsigned long n; while (svm_data)//将训练样本文件依次读取进来 { if (getline(svm_data, line)) { nLine++; pos = line.find_last_of(\' \'); img_path.push_back(line.substr(0, pos));//图像路径 img_catg.push_back(atoi(line.substr(pos + 1).c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错 } } svm_data.close();//关闭文件 int nImgNum = nLine; //nImgNum是样本数量,只有文本行数的一半,另一半是标签 cv::Mat data_mat(nImgNum, 324, CV_32FC1);//第二个参数,即矩阵的列是由下面的descriptors的大小决定的,可以由descriptors.size()得到,且对于不同大小的输入训练图片,这个值是不同的 data_mat.setTo(cv::Scalar(0)); //类型矩阵,存储每个样本的类型标志 cv::Mat res_mat(nImgNum, 1, CV_32S); res_mat.setTo(cv::Scalar(0)); cv::Mat src; cv::Mat trainImg(cv::Size(28, 28), 8, 3);//需要分析的图片,这里默认设定图片是28*28大小,所以上面定义了324,如果要更改图片大小,可以先用debug查看一下descriptors是多少,然后设定好再运行 //处理HOG特征 for (string::size_type i = 0; i != img_path.size(); i++) { src = cv::imread(img_path[i].c_str(), 1); if (src.data == NULL)//if (src == NULL) { cout << " can not load the image: " << img_path[i].c_str() << endl; continue; } //cout << " 处理: " << img_path[i].c_str() << endl; cv::resize(src, trainImg, trainImg.size()); cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(28, 28), cv::Size(14, 14), cv::Size(7, 7), cv::Size(7, 7), 9); vector<float>descriptors;//存放结果 hog->compute(trainImg, descriptors, cv::Size(1, 1), cv::Size(0, 0)); //Hog特征计算 //cout << "HOG dims: " << descriptors.size() << endl; n = 0; for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++) { //cvmSet(data_mat, i, n, *iter); data_mat.at<float>(i, n) = *iter;//存储HOG特征 n++; } //cvmSet(res_mat, i, 0, img_catg[i]); res_mat.at<int>(i, 0) = img_catg[i]; //cout << " 处理完毕: " << img_path[i].c_str() << " " << img_catg[i] << endl; } cout << "computed features!" << endl; cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();//新建一个SVM svm->setType(cv::ml::SVM::C_SVC); svm->setKernel(cv::ml::SVM::LINEAR); svm->setC(1); //-------------------不使用参数优化-------------------------// svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON)); svm->train(data_mat, cv::ml::ROW_SAMPLE, res_mat);//训练数据 //-------------------参数优化-------------------------// //svm->setTermCriteria = cv::TermCriteria(cv::TermCriteria::MAX_ITER, (int)1e7, 1e-6); //cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, res_mat); //svm->trainAuto(td, 10); //保存训练好的分类器 svm->save("HOG_SVM_DATA.xml"); cout << "saved model!" << endl; //检测样本 cv::Mat test;//IplImage *test; char result[512]; vector<string> img_test_path; vector<int> img_test_catg; int coorect = 0; ifstream img_tst("./test-images/random.txt"); //加载需要预测的图片集合,这个文本里存放的是图片全路径,不要标签 while (img_tst) { if (getline(img_tst, line)) { pos = line.find_last_of(\' \'); img_test_catg.push_back(atoi(line.substr(pos + 1).c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错 img_test_path.push_back(line.substr(0, pos));//图像路径 } } img_tst.close(); ofstream predict_txt("SVM_PREDICT.txt");//把预测结果存储在这个文本中 for (string::size_type j = 0; j != img_test_path.size(); j++)//依次遍历所有的待检测图片 { test = cv::imread(img_test_path[j].c_str(), 1); if (test.data == NULL)//test == NULL { cout << " can not load the image: " << img_test_path[j].c_str() << endl; continue; } cv::Mat trainTempImg(cv::Size(28, 28), 8, 3); trainTempImg.setTo(cv::Scalar(0)); cv::resize(test, trainTempImg, trainTempImg.size()); cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(28, 28), cv::Size(14, 14), cv::Size(7, 7), cv::Size(7, 7), 9); vector<float>descriptors;//结果数组 hog->compute(trainTempImg, descriptors, cv::Size(1, 1), cv::Size(0, 0)); //cout << "HOG dims: " << descriptors.size() << endl; cv::Mat SVMtrainMat(1, descriptors.size(), CV_32FC1); int n = 0; for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++) { SVMtrainMat.at<float>(0, n) = *iter; n++; } int ret = svm->predict(SVMtrainMat);//检测结果 if (ret == img_test_catg[j]) coorect++; sprintf(result, "%s %d\\r\\n", img_test_path[j].c_str(), ret); predict_txt << result; //输出检测结果到文本 } predict_txt.close(); cout << coorect*100 / img_test_path.size() << "%" << endl; return 0; }
测试
int main(int argc, char* argv[]) { cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create(); svm = cv::ml::SVM::load("HOG_SVM_DATA.xml");;//加载训练好的xml文件,这里训练的是10K个手写数字 //检测样本 cv::Mat test; char result[300]; //存放预测结果 test = cv::imread("6.bmp", 1); //待预测图片,用系统自带的画图工具随便手写 if (!test.data) { MessageBox(NULL, TEXT("待预测图像不存在!"), TEXT("提示"), MB_ICONWARNING); return -1; } cv::Mat trainTempImg(cv::Size(28, 28), 8, 3); trainTempImg.setTo(cv::Scalar(0)); cv::resize(test, trainTempImg, trainTempImg.size()); cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(28, 28), cv::Size(14, 14), cv::Size(7, 7), cv::Size(7, 7), 9); vector<float>descriptors;//结果数组 hog->compute(trainTempImg, descriptors, cv::Size(1, 1), cv::Size(0, 0)); //cout << "HOG dims: " << descriptors.size() << endl; cv::Mat SVMtrainMat(1, descriptors.size(), CV_32FC1); int n = 0; for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++) { SVMtrainMat.at<float>(0, n) = *iter; n++; } int ret = svm->predict(SVMtrainMat);//检测结果 sprintf(result, "%d\\r\\n", ret); cv::namedWindow("dst", 0); cv::imshow("dst", test); MessageBox(NULL, result, TEXT("预测结果"), MB_OK); return 0; }
以上是关于简单HOG+SVM mnist手写数字分类的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch Note25 深层神经网络实现 MNIST 手写数字分类