基于朴素贝叶斯分类器的文本分类

Posted kuailefangyuan

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于朴素贝叶斯分类器的文本分类相关的知识,希望对你有一定的参考价值。

  1. 实验要求
  2. 题目要求

1、用MapReduce算法实现贝叶斯分类器的训练过程,并输出训练模型;

2、用输出的模型对测试集文档进行分类测试。测试过程可基于单机Java程序,也可以是MapReduce程序。输出每个测试文档的分类结果;

3、利用测试文档的真实类别,计算分类模型的PrecisionRecallF1值。

2.实验环境

实验平台:VMware Workstation10

虚拟机系统:Suse11

集群环境:主机名master  ip:192.168.226.129

从机名slave1  ip:192.168.226.130

  1. 贝叶斯分类器理论介绍

贝叶斯分类器的分类原理是通过某对象的先验概率,利用贝叶斯公式计算出其后验概率,即该对象属于某一类的概率,选择具有最大后验概率的类作为该对象所属的类。

应用贝叶斯分类器进行分类主要分成两阶段。第一阶段是贝叶斯统计分类器的学习阶段,即根据训练数据集训练得出训练模型;第二阶段是贝叶斯分类器的推理阶段,即根据训练模型计算属于各个分类的概率,进行分类。

贝叶斯公式如下:

技术分享

其中AB分别为两个不同的事件,P(A)A先验概率P(A|B)是已知B发生后A条件概率,也由于得自B的取值而被称作A后验概率。而上式就是用事件B的先验概率来求它的后验概率。

  1. 贝叶斯分类器训练的MapReduce算法设计

3.1贝叶斯文本分类流程图

技术分享

3.2贝叶斯文本分类详细步骤

       整个文档归类过程可以分为以下步骤:

    1. 测试数据打包。将训练数据集中的大量小文本打包为SequencedFileMapReduce程序)。
    2. 文档统计以及单词统计。将1中输出的SequencedFile作为输入,分别进行文档个数统计DocCount(MapReduce程序)和各个分类下单词个数统计WordCount(MapReduce程序)
    3. 测试数据打包。将测试数据集中的大量小文本打包为SequcedFileMapReduce程序)。
    4. 文档归类。将2中输出的文档统计和单词统计的结果,分别计算文档的先验概率和单词在各个分类下的条件概率,然后将3中输出的SequencedFile作为输入,计算测试文本属于各个分类的概率,并选择其中最大概率的分类作为该文档的所属分类。

3.3具体算法设计

一个Country有多个news.txt, 一个news 有多个word

我们所设计的算法最后是要得到随机抽取一个txt文档,它最有可能属于哪个国家类别,也就是我们要得到它属于哪个国家的概率最大,把它转化为数学公式也就是:

技术分享         3-1

为了便于比较,我们将上式取对数得到:

技术分享       3-2

其中Num(Wi)表示该txt文档中单词Wi的个数;P(C|Wi) 表示拿出一个单词Wi,其属于国家C的后验概率。根据贝叶斯公式有:P(C|W) = P(W|C)*P(C)/P(W),其中:

P(W|C):国家Cnews中单词W出现的概率,根据式3-2,不能使该概率为0,所以我们约定每个国家都至少包含每个单词一次,也就是在统计单词数量时,都自动的加1,就有:

技术分享     3-3

P(C):国家C出现的概率(正比于其所含txt文件数);

P(W):单词W在整个测试集中出现的概率。

根据上面的贝叶斯公式我们设计的MapReduce算法如下:

  1. 按比例选取测试文档,其比例大致为国家包含文档数的相对比例;
  2. Map操作:一一遍历文档,得到<<C, Wi> , 1>
  3. Reduce操作:

合并<<C, W> , 1> 得到国家C中含有单词Wi的个数<<C, Wi> , ni>+1,记为N(C,Wi)

技术分享得到国家C中含有的单词总数,记为N(C)

技术分享得到测试集中单词W的总数,记为N(W)

再由技术分享得到测试集的单词总数,记为N

则可求得P(W|C) = N(C,W)/N(C)P(C) = N(C)/NP(W) = N(W)/N

 

3.4MapReduceData Flow示意图

技术分享

技术分享

技术分享

技术分享

技术分享

  1. 源代码清单

本实验中的主要代码如下所示

4.1 SmallFilesToSequenceFileConverter.java      小文件集合打包工具类MapReduce程序

4.2 WholeFileInputFormat.java      支持类:递归读取指定目录下的所有文件

4.3 WholeFileRecordReader.java 支持类:读取单个文件的全部内容

4.4 DocCount.java     文档统计MapReduce程序

4.5 WordCount.java   单词统计MapReduce程序

4.6 DocClassification.java   测试文档分类MapReduce程序

详细代码如下:

4.1 SmallFilesToSequenceFileConverter.java 其中MapReduce关键代码如下:

publicclass SmallFilesToSequenceFileConverter extends Configured implements Tool {

 

    staticclass SequenceFileMapper extends Mapper<NullWritable, BytesWritable, Text, BytesWritable> {

 

       private String fileNameKey; // 被打包的小文件名作为key,表示为Text对象

       private String classNameKey; // 当前文档所在的分类名

 

       @Override// 重新实现setup方法,进行map任务的初始化设置

       protectedvoid setup(Context context) throws IOException, InterruptedException {

           InputSplit split = context.getInputSplit(); // context获取split

           Path path = ((FileSplit) split).getPath(); // split获取文件路径

           fileNameKey = path.getName(); // 将文件路径实例化为key对象

           classNameKey = path.getParent().getName();

       }

 

       @Override// 实现map方法

       protectedvoid map(NullWritable key, BytesWritable value, Context context)

              throws IOException, InterruptedException {

           // 注意sequencefilekeyvalue key:分类,文档名  value:文档的内容)

           context.write(new Text(classNameKey + "/" + fileNameKey), value);

       }

    }

}

4.2 WholeFileInputFormat.java 其中关键代码如下:

publicclass WholeFileInputFormat extends FileInputFormat<NullWritable, BytesWritable> {

    /**

     * <p>方法描述:递归遍历输入目录下的所有文件</p>

     * <p>备注:该写FileInputFormat,使支持多层目录的输入</p>

     *  @authormeify DateTime 2015113下午2:37:49

     *  @param fs

     *  @param path

     */

    void search(FileSystem fs, Path path) {

       try {

           if (fs.isFile(path)) {

              fileStatus.add(fs.getFileStatus(path));

           } elseif (fs.isDirectory(path)) {

              FileStatus[] fileStatus = fs.listStatus(path);

              for (inti = 0; i < fileStatus.length; i++) {

                  FileStatus fileStatu = fileStatus[i];

                  search(fs, fileStatu.getPath());

              }

           }

       } catch (IOException e) {

           e.printStackTrace();

       }

    }

    @Override

    public RecordReader<NullWritable, BytesWritable> createRecordReader(InputSplit split, TaskAttemptContext context)

           throws IOException, InterruptedException {

       WholeFileRecordReader reader = new WholeFileRecordReader();

       reader.initialize(split, context);

       returnreader;

    }

    @Override

    protected List<FileStatus> listStatus(JobContext job) throws IOException {

      

       FileSystem fs = FileSystem.get(job.getConfiguration());

       // 输入根目录

       String rootDir = job.getConfiguration().get("mapred.input.dir", "");

       // 递归获取输入目录下的所有文件

       search(fs, new Path(rootDir));

       returnthis.fileStatus;

    }

}

4.3 WholeFileRecordReader.java 其中关键代码如下:

publicclass WholeFileRecordReader extends RecordReader<NullWritable, BytesWritable>{

 

    private FileSplit fileSplit; //保存输入的分片,它将被转换成一条( key value)记录

    private Configuration conf; //配置对象

    private BytesWritable value = new BytesWritable(); //value对象,内容为空

    privatebooleanprocessed = false; //布尔变量记录记录是否被处理过

    @Override

    publicboolean nextKeyValue() throws IOException, InterruptedException {

       if (!processed) { //如果记录没有被处理过

           //fileSplit对象获取split的字节数,创建byte数组contents

           byte[] contents = newbyte[(int) fileSplit.getLength()];

           Path file = fileSplit.getPath(); //fileSplit对象获取输入文件路径

           FileSystem fs = file.getFileSystem(conf); //获取文件系统对象

           FSDataInputStream in = null; //定义文件输入流对象

           try {

              in = fs.open(file); //打开文件,返回文件输入流对象

//从输入流读取所有字节到contents

              IOUtils.readFully(in, contents, 0, contents.length);          value.set(contents, 0, contents.length); //contens内容设置到value对象中

           } finally {

              IOUtils.closeStream(in); //关闭输入流

           }

          

           processed = true; //将是否处理标志设为true,下次调用该方法会返回false

           returntrue;

       }

           returnfalse; //如果记录处理过,返回false,表示split处理完毕

    }

}

4.4 DocCount.java  其中MapReduce关键代码如下:

publicclass DocCount extends Configured implements Tool{

 

    publicstaticclass Map extends Mapper<Text, BytesWritable, Text, IntWritable> {

       @Override

       publicvoid map(Text key, BytesWritable value, Context context) {

           try {

              String currentKey = key.toString();

              String[] arr = currentKey.split("/");

              String className = arr[0];

              String fileName = arr[1];

              System.out.println(className + "," + fileName);

              context.write(new Text(className), new IntWritable(1));

           } catch (IOException e) {

              e.printStackTrace();

           } catch (InterruptedException e) {

              e.printStackTrace();

           }

       }

    }

   

    publicstaticclass Reduce extends Reducer<Text, IntWritable, Text, IntWritable> {

       private IntWritable result = new IntWritable();

       publicvoid reduce(Text key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {

           intsum = 0;

           for (IntWritable val : values) {

              sum ++;

           }

           result.set(sum);

           context.write(key, result);  // 输出结果key: 分类 ,  value: 文档个数

       }

    }

}

4.5 WordCount.java  其中MapReduce关键代码如下:

publicclass WordCount extends Configured implements Tool{

 

    publicstaticclass Map extends Mapper<Text, BytesWritable, Text, IntWritable> {

 

       @Override

       publicvoid map(Text key, BytesWritable value, Context context) {

           try {

             

              String[] arr = key.toString().split("/");

              String className = arr[0];

              String fileName = arr[1];

              value.setCapacity(value.getSize()); // 剔除多余空间

              // 文本内容

               String content = new String(value.getBytes(), 0, value.getLength());

              StringTokenizer itr = new StringTokenizer(content);

              while (itr.hasMoreTokens()) {

                  String word = itr.nextToken();

                  if(StringUtil.isValidWord(word))

                  {

                     System.out.println(className + "/" + word);

                     context.write(new Text(className + "/" + word), new IntWritable(1));

                  }

              }

           } catch (IOException e) {

              e.printStackTrace();

           } catch (InterruptedException e) {

              e.printStackTrace();

           }

       }

    }

   

    publicstaticclass Reduce extends Reducer<Text, IntWritable, Text, IntWritable> {

       private IntWritable result = new IntWritable();

       publicvoid reduce(Text key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {

           intsum = 1; // 注意这里单词的个数从1开始计数

           for (IntWritable val : values) {

              sum ++;

           }

           result.set(sum);

           context.write(key, result);  // 输出结果key: 分类/ 单词 ,  value: 频次

       }

    }

}

4.6 DocClassification.java 其中MapReduce关键代码如下:

publicclass DocClassification extends Configured implements Tool {

    // 所有分类集合

    privatestatic List<String> classList = new ArrayList<String>();

    // 所有分类的先验概率(其中的概率取对数log

    privatestatic HashMap<String, Double> classProMap = new HashMap<String, Double>();

    // 所有单词在各个分类中的出现的频次

    privatestatic HashMap<String, Integer> classWordNumMap = new HashMap<String, Integer>();

    // 分类下的所有单词出现的总频次

    privatestatic HashMap<String, Integer> classWordSumMap = new HashMap<String, Integer>();

    privatestatic Configuration conf = new Configuration();

    static {

       // 初始化分类先验概率词典

       initClassProMap("hdfs://192.168.226.129:9000/user/hadoop/doc");

       // 初始化单词在各个分类中的条件概率词典

       initClassWordProMap("hdfs://192.168.226.129:9000/user/hadoop/word");

    }

   

    publicstaticclass Map extends Mapper<Text, BytesWritable, Text, Text> {

       @Override

       publicvoid map(Text key, BytesWritable value, Context context) {

           String fileName = key.toString();

           value.setCapacity(value.getSize()); // 剔除多余空间

           String content = new String(value.getBytes(), 0, value.getLength());

           try {

              for (String className : classList) {

                  doubleresult = Math.log(classProMap.get(className));

                  StringTokenizer itr = new StringTokenizer(content);

                  while (itr.hasMoreTokens()) {

                     String word = itr.nextToken();

                     if (StringUtil.isValidWord(word)) {

                         intwordSum = 1;

                         if(classWordNumMap.get(className + "/" + word) != null){

                            wordSum = classWordNumMap.get(className + "/" + word);

                         }

                         intclassWordSum = classWordSumMap.get(className);

                         doublepro_class_word = Math.log(((double)wordSum)/classWordSum);

                         result += pro_class_word;

                     }

                  }

                  // 输出的形式 key:文件名 value:分类名/概率

                  context.write(new Text(fileName), new Text(className + "/" + String.valueOf(result)));

              }

           } catch (IOException e) {

              e.printStackTrace();

           } catch (InterruptedException e) {

              e.printStackTrace();

           }

       }

    }

 

    publicstaticclass Reduce extends Reducer<Text, Text, Text, Text> {

 

       publicvoid reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException {

           String fileName = key.toString().split("/")[1];

    doublemaxPro = Math.log(Double.MIN_VALUE);

           String maxClassName = "unknown";

           for (Text value : values) {

              String[] arr = value.toString().split("/");

              String className = arr[0];

              doublepro = Double.valueOf(arr[1]);

              if (pro > maxPro) {

                  maxPro = pro;

                  maxClassName = className;

              }

           }

           System.out.println("fileName:" + fileName + ",belong class:" + maxClassName);

           // 输出 key:文件名 value:所属分类名以及概率

           context.write(new Text(fileName), new Text(maxClassName + ",pro=" + maxPro));

       }

    }

}

四、数据集说明

训练集:CHINA  文档数255

INDIA   文档数326

TAIWAN  文档数43.

测试集:CHINA   文档个数15

INDIA    文档个数20

TAIWAN  文档个数15

  1. 程序运行说明

5.1训练数据集打包程序

Map任务个数624(所有小文件的个数)   Reduce任务个数1

截图如下

技术分享

技术分享

 

5.2训练文档统计程序

Map任务个数1(输入为1SequencedFile   Reduce任务个数1

技术分享

技术分享

5.3训练单词统计程序

Map任务个数1(输入为1SequencedFile)   Reduce任务个数1

技术分享

技术分享

5.4测试数据集打包程序

Map任务个数50(测试数据集小文件个数为50)   Reduce任务个数1

技术分享

技术分享

5.5测试文档归类程序

Map任务个数1(输入为1SequencedFile   Reduce任务个数1

技术分享

技术分享

 

  1. 实验结果分析

测试集文档归类结果截图如下:

技术分享

 针对CHINA TAIWAN INDIA三个分类下的测试文档进行测试结果如下表所示:

类别(国家)

正确率

召回率

F1

CHINA

18.4%

46.667%

26.38%

INDIA

42.1%

80%

55.67%

TAIWAN

39.47%

100%

56.60%

以上是关于基于朴素贝叶斯分类器的文本分类的主要内容,如果未能解决你的问题,请参考以下文章

《机器学习实战》基于朴素贝叶斯分类算法构建文本分类器的Python实现

性能:提高朴素贝叶斯分类器的准确性

如何生成混淆矩阵并找到朴素贝叶斯分类器的错误分类率?

朴素贝叶斯分类算法的sklearn实现

朴素贝叶斯分类器的平衡语料库

贝叶斯分类器(3)朴素贝叶斯分类器