机器学习 demo分西瓜

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习 demo分西瓜相关的知识,希望对你有一定的参考价值。

周老师的书,对神经网络写了一个小的Demo

是最简单的神经网络,只有一层的隐藏层。

这次练习依旧是对西瓜的好坏进行预测。

 

主要分了以下几个步骤

1、数据预处理

对西瓜的不同特性进行数学编码表示(0~1),我是直接编了对应数字。含糖量已经是一个0~1之间的数,所以就没有进行处理

 

青绿  1

乌黑 0.5

浅白  0

 

蜷缩  1

稍蜷 0.5

硬挺  0

 

浊响  1

沉闷 0.5

清脆  0

 

清晰  1

稍糊 0.5

模糊  0

 

凹陷  1

稍凹 0.5

平坦  0

 

硬滑  1

软黏  0

2、训练集和检测集

 

[java] view plain copy
 
  1. package BP;  
  2.   
  3. public class TrainData {  
  4.     double[][] traindata;  
  5.     double[][] traindataoutput;  
  6.     double[][] testdata;  
  7.     double[][] testdataoutput;  
  8.     public TrainData(){  
  9.         traindata = new double[][]{  
  10.             new double[]{1,1,1,1,1,1,0.697,0.460},    
  11.             new double[]{0.5,1,0.5,1,1,1,0.774,0.376},  
  12.             new double[]{0.5,1,1,1,1,1,0.634,0.264},  
  13.             //new double[]{1,1,0.5,1,1,1,0.608,0.318,1},  
  14.             //new double[]{0,1,1,1,1,1,0.556,0.215,1},  
  15.             new double[]{1,0.5,1,1,0.5,0,0.403,0.237},  
  16.             new double[]{0.5,0.5,1,0.5,0.5,0,0.481,0.149},  
  17.             //new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211,1},  
  18.               
  19.             //new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091,0},  
  20.             //new double[]{1,0,0,1,0,0,0.243,0.267,0},  
  21.             //new double[]{0,0,0,0,0,1,0.245,0.057,0},  
  22.             //new double[]{0,1,1,0,0,0,0.343,0.099,0},  
  23.             new double[]{1,0.5,1,0.5,1,1,0.639,0.161},  
  24.             new double[]{0,0.5,0,0.5,1,1,0.657,0.198},  
  25.             new double[]{0.5,0.5,1,1,0.5,0,0.360,0.370},  
  26.             new double[]{0,1,1,0,0,1,0.593,0.042},  
  27.             new double[]{1,1,0.5,0.5,0.5,1,0.719,0.103}  
  28.         };  
  29.         traindataoutput = new double[][]{  
  30.             new double[]{1},  
  31.             new double[]{1},  
  32.             new double[]{1},  
  33.             new double[]{1},  
  34.             new double[]{1},  
  35.             new double[]{0},  
  36.             new double[]{0},  
  37.             new double[]{0},  
  38.             new double[]{0},  
  39.             new double[]{0},  
  40.         };  
  41.         testdata = new double[][]{  
  42.             new double[]{1,1,0.5,1,1,1,0.608,0.318},  
  43.             new double[]{0,1,1,1,1,1,0.556,0.215},  
  44.             new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211},  
  45.               
  46.             new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091},  
  47.             new double[]{1,0,0,1,0,0,0.243,0.267},  
  48.             new double[]{0,0,0,0,0,1,0.245,0.057},  
  49.             new double[]{0,1,1,0,0,0,0.343,0.099},  
  50.         };  
  51.         testdataoutput = new double[][]{  
  52.             new double[]{1},  
  53.             new double[]{1},  
  54.             new double[]{1},  
  55.             new double[]{0},  
  56.             new double[]{0},  
  57.             new double[]{0},  
  58.             new double[]{0},  
  59.         };  
  60.     }  
  61.     public static void main(String[] args){  
  62.         TrainData t = new TrainData();  
  63.         for(int i=0;i<t.traindata.length;i++){  
  64.             for(int j=0;j<9;j++)  
  65.                 System.out.print(t.traindata[i][j]+ " ");  
  66.             System.out.println();  
  67.         }  
  68.     }  
  69. }  

3、BP主函数

 

 

[java] view plain copy
 
  1. package BP;  
  2.   
  3. import java.util.Random;  
  4.   
  5. public class BP {  
  6.     int innum;  
  7.     int hiddennum;  
  8.     int outnum;  
  9.     //输入、隐藏、输出层  
  10.     public double[] input;  
  11.     public double[] hidden;  
  12.     //output为本神经网络计算出的输出值  
  13.     public double[] output;  
  14.   
  15.     //realoutput为训练网络时,用户提供的真的输出值  
  16.     public double[] realoutput;  
  17.   
  18.     //v[i,j]表示输入层i到隐层j  w[i,j]表示隐层i到输出层j  
  19.     public double[][] v;  
  20.     public double[][] w;  
  21.   
  22.     //beta为隐层的阈值,afa为输出层阈值  
  23.     public double[] beta;  
  24.     public double[] afa;  
  25.   
  26.     //学习率  
  27.     public double eta;  
  28.     //步长  
  29.     public double momentum;  
  30.     public final Random random;  
  31.   
  32.     public BP(int inputnum,int hiddennum,int outputnum,double learningrate){  
  33.         innum = inputnum;  
  34.         this.hiddennum = hiddennum;  
  35.         outnum = outputnum;  
  36.   
  37.         input = new double[inputnum + 1];  
  38.         hidden = new double[hiddennum + 1];  
  39.         output = new double[outputnum + 1];  
  40.         realoutput = new double[outputnum + 1];  
  41.   
  42.         v = new double[inputnum + 1][hiddennum + 1];  
  43.         w = new double[hiddennum + 1][outputnum + 1];  
  44.   
  45.         beta = new double[outputnum + 1];  
  46.         afa = new double[hiddennum + 1];  
  47.         for(int i=0;i<outputnum;i++)  
  48.             beta[i] = 0.0;  
  49.         for(int i=0;i<hiddennum;i++)  
  50.             afa[i] = 0.0;  
  51.   
  52.         eta = learningrate;  
  53.         //随机数对结果影响较大  
  54.         random = new Random(19950326);  
  55.         randomizeWeights(w);  
  56.         randomizeWeights(v);  
  57.     }  
  58.   
  59.     public void testData(double[] in){  
  60.         input = in;  
  61.         getNetOutput();  
  62.     }  
  63.     //只对本题目有用,output>0.5时为好西瓜,output<0.5时为坏西瓜  
  64.     public int predict(double[] in){  
  65.         testData(in);  
  66.         if(output[0]>0.5)  
  67.             return 1;  
  68.         else  
  69.             return 0;  
  70.     }  
  71.     //获得在test集上的正确率  
  72.     public double getAccuracy(double[][] in,double[][] out){  
  73.         int rightans = 0,wrongans = 0;  
  74.         for(int i=0;i<in.length;i++){  
  75.             if(predict(in[i])==(out[i][0])){  
  76.                 //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);  
  77.                 rightans++;  
  78.             }else{  
  79.                 //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);  
  80.                 wrongans++;  
  81.             }  
  82.         }  
  83.         System.out.println("对:"+rightans+" 错:"+wrongans);  
  84.         return (double)rightans/(double)(rightans+wrongans);  
  85.     }  
  86.     //times为进行几轮训练  
  87.     public void train(int times){  
  88.         TrainData t = new TrainData();  
  89.         double wu = 0.0,acc = 0.0;  
  90.         int n = t.traindata.length;  
  91.         for(int i=0;i<times;i++){  
  92.             wu = 0.0;  
  93.             for(int j=0;j<n;j++){  
  94.                 traindata(t.traindata[j],t.traindataoutput[j]);  
  95.                 wu += getDeviation();  
  96.             }  
  97.             wu = wu/((double)n);  
  98.             System.out.println("第"+i+"轮训练:"+wu);  
  99.             acc = getAccuracy(t.testdata,t.testdataoutput);  
  100.             System.out.println("预测正确率为: "+acc);  
  101.         }  
  102.     }  
  103.     //对一个input输入进行训练  
  104.     public void traindata(double[] in,double[] out){  
  105.         input = in;  
  106.         realoutput = out;  
  107.         getNetOutput();  
  108.         adjustParameter();  
  109.     }  
  110.     //获得误差E  
  111.     public double getDeviation(){  
  112.         double e = 0.0;  
  113.         for(int i=0;i<outnum;i++)  
  114.             e += (output[i] - realoutput[i])*(output[i] - realoutput[i]);  
  115.         e *= 0.5;  
  116.         return e;  
  117.     }  
  118.     //调整权值  
  119.     public void adjustParameter(){  
  120.         double g[],e = 0.0;  
  121.         g = new double[outnum];  
  122.         int i,j;  
  123.         for(i=0;i<outnum;i++){  
  124.             g[i] = output[i]*(1-output[i])*(realoutput[i]-output[i]);  
  125.             beta[i] -= eta * g[i];  
  126.             for(j=0;j<hiddennum;j++){  
  127.                 w[j][i] += eta * g[i] * hidden[j];  
  128.             }  
  129.         }  
  130.         for(i=0;i<hiddennum;i++){  
  131.             e = 0.0;  
  132.             for(j=0;j<outnum;j++)  
  133.                 e += g[j]*w[i][j];  
  134.             e = hidden[i]*(1-hidden[i])*e;  
  135.             afa[i] -= eta * e;  
  136.             for(j=0;j<innum;j++)  
  137.                 v[j][i] += eta * e * input[j];  
  138.         }  
  139.     }  
  140.     //获得output  
  141.     public void getNetOutput(){  
  142.         int i,j;  
  143.         double tmp=0.0;  
  144.         for(i=0;i<hiddennum;i++){  
  145.             tmp = 0.0;  
  146.             for(j=0;j<innum;j++)  
  147.                 tmp += v[j][i]*input[j];  
  148.             hidden[i] = sigmoid(tmp-afa[i]);  
  149.         }  
  150.         for(i=0;i<outnum;i++){  
  151.             tmp = 0.0;  
  152.             for(j=0;j<hiddennum;j++)  
  153.                 tmp += w[j][i]*hidden[j];  
  154.             output[i] = sigmoid(tmp-beta[i]);  
  155.         }  
  156.     }  
  157.     //对权值矩阵w、v进行初始随机化  
  158.     private void randomizeWeights(double[][] matrix) {  
  159.         for (int i = 0, len = matrix.length; i != len; i++)  
  160.             for (int j = 0, len2 = matrix[i].length; j != len2; j++) {  
  161.                 double real = random.nextDouble();  
  162.                 matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;  
  163.             }  
  164.     }  
  165.     public void debug(){  
  166.         System.out.println("========begin=======");  
  167.         for(int i=0;i<innum;i++){  
  168.             for(int j=0;j<hiddennum;j++)  
  169.                 System.out.print(v[i][j]+" ");  
  170.             System.out.println();  
  171.         }  
  172.         System.out.println();  
  173.         for(int i=0;i<hiddennum;i++){  
  174.             for(int j=0;j<outnum;j++)  
  175.                 System.out.print(w[i][j]+" ");  
  176.             System.out.println();  
  177.         }  
  178.         System.out.println("========end=======");  
  179.     }  
  180.     public double sigmoid(double z){  
  181.         double s = 0.0;  
  182.         s = 1d/(1d + Math.exp(-z));  
  183.         return s;  
  184.     }  
  185.   
  186.     public static void main(String[] args){  
  187.         BP bp = new BP(8,10,1,0.1);  
  188.         bp.train(50);  
  189.     }  
  190. }  


我要说的:

 

        就结果来说,在验证集上的正确率可达到85%,当然很大程度上取决于BP初始化时random函数的种子。运气好的时候甚至能达到100%的正确率,运气不好的时候只有40%多,跟随便乱猜没什么区别。

        想问大神。。。只能采用这种随机算法来找到一个最合适的ramdom种子值嘛?能不能用遗传这样的开放式算法进行搜索来找到最合适的随机值(我觉得随机的种子和随机结果并没有什么直接的关联,所以不知道能不能用遗传算法之列。。。)


以上是关于机器学习 demo分西瓜的主要内容,如果未能解决你的问题,请参考以下文章

《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”

建一个网站,用机器学习挑西瓜

西瓜书笔记:机器学习相关会议及期刊

机器学习基础概念之监督学习与无监督学习

《西瓜书机器学习详细公式推导版》发布

西瓜书笔记:机器学习相关会议及期刊