fastText source code analysis

Posted by miner007



Purpose: to record a walkthrough of the code that combines material from several sources with my own understanding.

https://heleifz.github.io/14732610572844.html

http://www.cnblogs.com/peghoty/p/3857839.html

Part One: how the modules of the code base relate to one another:

The core modules are fasttext.cc and model.cc, but the auxiliary modules matter as well: they are the screws that hold the code together, and they determine which data structures are used to organize the data. They are well worth studying, and you will notice that the structures used to store training data are fairly standard; later on it is worth comparing how several different code bases organize their training data.

This part: analysis of the "screw" code (the helper modules)

Part Two: the dictionary module

/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "dictionary.h"

#include <assert.h>

#include <iostream>
#include <algorithm>
#include <iterator>
#include <unordered_map>

namespace fasttext {

const std::string Dictionary::EOS = "</s>";
const std::string Dictionary::BOW = "<";
const std::string Dictionary::EOW = ">";

Dictionary::Dictionary(std::shared_ptr<Args> args) {
  args_ = args;
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  ntokens_ = 0;
  word2int_.resize(MAX_VOCAB_SIZE); // index over all words; hash slots lie in 0..MAX_VOCAB_SIZE-1
  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
    word2int_[i] = -1;
  }
}
// Hash the string; on collision, probe linearly until the slot holding this word (or an empty slot) is found
int32_t Dictionary::find(const std::string& w) const {
  int32_t h = hash(w) % MAX_VOCAB_SIZE;
  while (word2int_[h] != -1 && words_[word2int_[h]].word != w) {
    h = (h + 1) % MAX_VOCAB_SIZE;
  }
  return h;
}
// Add a word to words_; the word may also be a label
void Dictionary::add(const std::string& w) {
  int32_t h = find(w);
  ntokens_++; // tokens processed so far
  if (word2int_[h] == -1) {
    entry e;
    e.word = w;
    e.count = 1;
    e.type = (w.find(args_->label) == 0) ? entry_type::label : entry_type::word; // words starting with the label prefix are labels
    words_.push_back(e);
    word2int_[h] = size_++;
  } else {
    words_[word2int_[h]].count++;
  }
}
// Number of distinct regular words
int32_t Dictionary::nwords() const {
  return nwords_;
}
// Number of distinct label words
int32_t Dictionary::nlabels() const {
  return nlabels_;
}
// Number of tokens processed so far (with repetitions)
int64_t Dictionary::ntokens() const {
  return ntokens_;
}
// Get the ngrams of a regular word, by index
const std::vector<int32_t>& Dictionary::getNgrams(int32_t i) const {
  assert(i >= 0);
  assert(i < nwords_);
  return words_[i].subwords;
}
// Get the ngrams of a word, by string
const std::vector<int32_t> Dictionary::getNgrams(const std::string& word) const {
  int32_t i = getId(word);
  if (i >= 0) {
    return getNgrams(i);
  }
  // The word is out of vocabulary: compute its ngrams on the fly,
  // so an unknown word can still be represented through the ngrams it shares with known words
  std::vector<int32_t> ngrams;
  computeNgrams(BOW + word + EOW, ngrams);
  return ngrams;
}
// Subsampling test: very frequent words carry little information, so they may be discarded
bool Dictionary::discard(int32_t id, real rand) const {
  assert(id >= 0);
  assert(id < nwords_);
  if (args_->model == model_name::sup) return false; // supervised (classification) models never subsample
  return rand > pdiscard_[id];
}
// Get the id of a word
int32_t Dictionary::getId(const std::string& w) const {
  int32_t h = find(w);
  return word2int_[h];
}
// Get the type of a word (regular word or label)
entry_type Dictionary::getType(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].type;
}
// Get the word string by id
std::string Dictionary::getWord(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].word;
}
// Hash rule (FNV-1a)
uint32_t Dictionary::hash(const std::string& str) const {
  uint32_t h = 2166136261;
  for (size_t i = 0; i < str.size(); i++) {
    h = h ^ uint32_t(str[i]);
    h = h * 16777619;
  }
  return h;
}
// Compute the character ngrams of a word
void Dictionary::computeNgrams(const std::string& word,
                               std::vector<int32_t>& ngrams) const {
  for (size_t i = 0; i < word.size(); i++) {
    std::string ngram;
    if ((word[i] & 0xC0) == 0x80) continue; // skip UTF-8 continuation bytes so ngrams start on character boundaries
    for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) { // grow the ngram one character at a time, up to maxn characters
      ngram.push_back(word[j++]);
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]);
      }
      if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
        int32_t h = hash(ngram) % args_->bucket; // bucket index of the ngram
        ngrams.push_back(nwords_ + h);
      }
    }
  }
}
// Initialize the ngram list of every word in the dictionary
void Dictionary::initNgrams() {
  for (size_t i = 0; i < size_; i++) {
    std::string word = BOW + words_[i].word + EOW;
    words_[i].subwords.push_back(i);
    computeNgrams(word, words_[i].subwords);
  }
}
// Read a single token from the stream
bool Dictionary::readWord(std::istream& in, std::string& word) const
{
  char c;
  std::streambuf& sb = *in.rdbuf();
  word.clear();
  while ((c = sb.sbumpc()) != EOF) {
    if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' || c == '\f' || c == '\0') {
      if (word.empty()) {
        if (c == '\n') { // an empty line yields an EOS token
          word += EOS;
          return true;
        }
        continue;
      } else {
        if (c == '\n')
          sb.sungetc(); // put the newline back so it is turned into EOS on the next call
        return true;
      }
    }
    word.push_back(c);
  }
  // trigger eofbit
  in.get();
  return !word.empty();
}
// Read the whole file: build the dictionary, then initialize the discard table and the ngrams
void Dictionary::readFromFile(std::istream& in) {
  std::string word;
  int64_t minThreshold = 1; // pruning threshold
  while (readWord(in, word)) {
    add(word);
    if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
      std::cout << "\rRead " << ntokens_  / 1000000 << "M words" << std::flush;
    }
    if (size_ > 0.75 * MAX_VOCAB_SIZE) { // keep the vocabulary below 75% of MAX_VOCAB_SIZE
      minThreshold++;
      threshold(minThreshold, minThreshold); // prune words rarer than minThreshold (and sort as a side effect)
    }
  }
  threshold(args_->minCount, args_->minCountLabel); // final sort, pruning with the user-specified thresholds

  initTableDiscard();
  initNgrams();
  if (args_->verbose > 0) {
    std::cout << "\rRead " << ntokens_  / 1000000 << "M words" << std::endl;
    std::cout << "Number of words:  " << nwords_ << std::endl;
    std::cout << "Number of labels: " << nlabels_ << std::endl;
  }
  if (size_ == 0) {
    std::cerr << "Empty vocabulary. Try a smaller -minCount value." << std::endl;
    exit(EXIT_FAILURE);
  }
}
// Prune low-frequency entries and sort the dictionary
void Dictionary::threshold(int64_t t, int64_t tl) {
  sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
      if (e1.type != e2.type) return e1.type < e2.type; // different types: labels sort after regular words
      return e1.count > e2.count; // same type: sort by descending count
    });
  words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
        return (e.type == entry_type::word && e.count < t) ||
               (e.type == entry_type::label && e.count < tl);
      }), words_.end()); // drop entries below the thresholds
  words_.shrink_to_fit(); // release the excess capacity
  // rebuild the dictionary bookkeeping
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
    word2int_[i] = -1; // reset the hash table
  }
  for (auto it = words_.begin(); it != words_.end(); ++it) {
    int32_t h = find(it->word); // recompute the hash slot
    word2int_[h] = size_++;
    if (it->type == entry_type::word) nwords_++;
    if (it->type == entry_type::label) nlabels_++;
  }
}
// Initialize the discard table; t is the subsampling threshold (0 discards everything, 1 disables subsampling)
void Dictionary::initTableDiscard() {
  pdiscard_.resize(size_);
  for (size_t i = 0; i < size_; i++) {
    real f = real(words_[i].count) / real(ntokens_); // relative frequency of the word
    pdiscard_[i] = sqrt(args_->t / f) + args_->t / f; // note: this seems to differ from the paper's formula
  }
}
// Return the counts of all entries of the given type
std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
  std::vector<int64_t> counts;
  for (auto& w : words_) {
    if (w.type == type) counts.push_back(w.count);
  }
  return counts;
}
// Append word-ngram ids (of order up to n) to a line of word ids
void Dictionary::addNgrams(std::vector<int32_t>& line, int32_t n) const {
  int32_t line_size = line.size();
  for (int32_t i = 0; i < line_size; i++) {
    uint64_t h = line[i];
    for (int32_t j = i + 1; j < line_size && j < i + n; j++) {
      h = h * 116049371 + line[j];
      line.push_back(nwords_ + (h % args_->bucket));
    }
  }
}
// Read one line of input into word ids and label ids
int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::vector<int32_t>& labels,
                            std::minstd_rand& rng) const {
  std::uniform_real_distribution<> uniform(0, 1); // uniform random number in [0, 1)
  std::string token;
  int32_t ntokens = 0;
  words.clear();
  labels.clear();
  if (in.eof()) {
    in.clear();
    in.seekg(std::streampos(0));
  }
  while (readWord(in, token)) {
    if (token == EOS) break; // end of the line
    int32_t wid = getId(token);
    if (wid < 0) continue; // unknown word: skip it
    entry_type type = getType(wid);
    ntokens++; // tokens read so far
    if (type == entry_type::word && !discard(wid, uniform(rng))) { // subsampling: keep the word only if it survives the random draw
      words.push_back(wid); // word ids are always below nwords_
    }
    if (type == entry_type::label) { // labels are always kept; their ids start at nwords_
      labels.push_back(wid - nwords_); // so nwords_ must be added back to recover the dictionary entry
    }
    if (words.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break; // only unsupervised models cap the sentence length
  }
  return ntokens;
}
// Get a label string from its label id
std::string Dictionary::getLabel(int32_t lid) const {
  assert(lid >= 0);
  assert(lid < nlabels_);
  return words_[lid + nwords_].word;
}
// Save the dictionary
void Dictionary::save(std::ostream& out) const {
  out.write((char*) &size_, sizeof(int32_t));
  out.write((char*) &nwords_, sizeof(int32_t));
  out.write((char*) &nlabels_, sizeof(int32_t));
  out.write((char*) &ntokens_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    entry e = words_[i];
    out.write(e.word.data(), e.word.size() * sizeof(char));
    out.put(0); // a null byte marks the end of the string
    out.write((char*) &(e.count), sizeof(int64_t));
    out.write((char*) &(e.type), sizeof(entry_type));
  }
}
// Load the dictionary
void Dictionary::load(std::istream& in) {
  words_.clear();
  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
    word2int_[i] = -1;
  }
  in.read((char*) &size_, sizeof(int32_t));
  in.read((char*) &nwords_, sizeof(int32_t));
  in.read((char*) &nlabels_, sizeof(int32_t));
  in.read((char*) &ntokens_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    char c;
    entry e;
    while ((c = in.get()) != 0) {
      e.word.push_back(c);
    }
    in.read((char*) &e.count, sizeof(int64_t));
    in.read((char*) &e.type, sizeof(entry_type));
    words_.push_back(e);
    word2int_[find(e.word)] = i; // rebuild the hash index
  }
  initTableDiscard(); // rebuild the discard table
  initNgrams();       // rebuild the ngrams
}

}
dictionary.cc

A few points that I think are worth spelling out:

1: How a word string is mapped to its dictionary entry, and how the index is built. The central function is find: internally it uses the hash function to pick a slot, and it ties two vectors together, so the lookup chain is StrToHash (the find function) → HashToIndex (the word2int_ array) → IndexToStruct (the words_ array).
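To make the three-step chain concrete, here is a minimal, self-contained sketch of the same open-addressing scheme. The MiniDict and MiniEntry names and the small table size are made up for illustration; the FNV-1a constants and the linear-probing logic mirror the listing above.

#include <cstdint>
#include <string>
#include <vector>

// Toy version of the lookup chain: string -> hash slot (find) -> index (word2int) -> entry (words).
struct MiniEntry { std::string word; int64_t count; };

struct MiniDict {
  static constexpr int32_t TABLE_SIZE = 1000003;    // stands in for MAX_VOCAB_SIZE
  std::vector<int32_t> word2int = std::vector<int32_t>(TABLE_SIZE, -1);
  std::vector<MiniEntry> words;

  static uint32_t fnv1a(const std::string& s) {     // same FNV-1a hash as Dictionary::hash
    uint32_t h = 2166136261u;
    for (char c : s) { h ^= uint32_t(c); h *= 16777619u; }
    return h;
  }
  int32_t find(const std::string& w) const {        // linear probing on collision
    int32_t h = fnv1a(w) % TABLE_SIZE;
    while (word2int[h] != -1 && words[word2int[h]].word != w) h = (h + 1) % TABLE_SIZE;
    return h;
  }
  void add(const std::string& w) {                  // insert a new entry or bump the count
    int32_t h = find(w);
    if (word2int[h] == -1) { words.push_back({w, 1}); word2int[h] = int32_t(words.size()) - 1; }
    else { words[word2int[h]].count++; }
  }
};

Looking a word up is then words[word2int[find(w)]]; an empty slot (word2int[h] == -1) means the word is unknown.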

2: Several lookup tables are initialized up front, to speed up the later passes over the data:

1) The ngram table: every word in the dictionary gets a list of ngram ids. For example, for the word "我想你", computeNgrams produces the indices of its character ngrams; with a minimum ngram length of 2 and a maximum of 3 the subwords are "<我", "我想", "想你", "你>" and "<我想", "我想你", "想你>". The "<" and ">" appear because a begin-of-word and an end-of-word marker are added automatically. Note the check "(word[j] & 0xC0) == 0x80" in the implementation: it detects UTF-8 continuation bytes, so a multi-byte Chinese character is always taken as a whole and counts as one "character" (a standalone sketch of this extraction is given right after item 2) below).

2) The discard table, built by initTableDiscard: for every word a keep-probability is derived from its frequency, and when a line is read a uniform random number is drawn; the word is dropped if the draw exceeds the table value. The point is that extremely frequent words ("的", "是", ...) are mostly uninformative, so they are subsampled. The implementation differs slightly from the paper: writing p = sqrt(t/f) for a word of frequency f and threshold t, the paper keeps a word with probability p (discards it with 1 - p), while the code keeps it with probability p + p², i.e. sqrt(t/f) + t/f. Since p + p² > p, the code's subsampling is a bit more lenient, i.e. fewer frequent words are actually discarded; the same formula appears in the original word2vec C implementation, which is probably where it was taken from.
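As a sketch of 1), the following standalone snippet reproduces the character-ngram extraction for "我想你" with a minimum length of 2 and a maximum of 3, including the continuation-byte check. It is not fastText's API; the helper name charNgrams is made up, and the loop is the same one as in computeNgrams, rewritten against plain std::string and collecting the ngram strings instead of hashed ids.

#include <iostream>
#include <string>
#include <vector>

// Extract character ngrams of length [minn, maxn] from a word, treating each
// UTF-8 code point as one character (continuation bytes have the form 10xxxxxx).
std::vector<std::string> charNgrams(const std::string& word, size_t minn, size_t maxn) {
  std::vector<std::string> ngrams;
  std::string w = "<" + word + ">";            // add begin/end-of-word markers, like BOW/EOW
  for (size_t i = 0; i < w.size(); i++) {
    if ((w[i] & 0xC0) == 0x80) continue;       // only start ngrams on character boundaries
    std::string ngram;
    for (size_t j = i, n = 1; j < w.size() && n <= maxn; n++) {
      ngram.push_back(w[j++]);
      while (j < w.size() && (w[j] & 0xC0) == 0x80) {
        ngram.push_back(w[j++]);               // pull in the rest of the multi-byte character
      }
      if (n >= minn && !(n == 1 && (i == 0 || j == w.size()))) {
        ngrams.push_back(ngram);
      }
    }
  }
  return ngrams;
}

int main() {
  for (const auto& g : charNgrams("我想你", 2, 3)) {
    std::cout << g << std::endl;               // prints: <我 <我想 我想 我想你 想你 想你> 你>
  }
  return 0;
}

The real computeNgrams additionally hashes each ngram into one of args_->bucket slots and stores nwords_ + hash, which is how an out-of-vocabulary word can still be assembled from subword vectors.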

3: From the outside, this file is driven mainly through two entry points: readFromFile, which loads the words and builds the dictionary, and getLine, which fetches the words of one sentence.
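To show how these two entry points are typically driven, here is a simplified, illustrative sketch; the function sketchTrain is hypothetical, and the real loop lives in fasttext.cc, where it additionally handles threading, learning-rate decay and the supervised addNgrams step.

#include <cstdint>
#include <fstream>
#include <memory>
#include <random>
#include <vector>

#include "args.h"
#include "dictionary.h"

// Illustrative driver (not the real fasttext.cc code): one pass to build the
// dictionary, then repeated getLine calls until every token has been seen
// `epoch` times, mirroring the structure of the training loop.
void sketchTrain(std::shared_ptr<fasttext::Args> args) {
  auto dict = std::make_shared<fasttext::Dictionary>(args);
  std::ifstream ifs(args->input);
  dict->readFromFile(ifs);                     // pass 1: vocabulary, discard table, ngrams
  ifs.clear();
  ifs.seekg(std::streampos(0));

  std::minstd_rand rng(0);
  std::vector<int32_t> line, labels;
  int64_t tokenCount = 0;
  while (tokenCount < args->epoch * dict->ntokens()) {
    tokenCount += dict->getLine(ifs, line, labels, rng);  // one subsampled line
    // Supervised mode would now call dict->addNgrams(line, args->wordNgrams)
    // and update the model with (line, labels); skip-gram/cbow would instead
    // slide a context window over `line`.
  }
}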

The remaining helper files, vector.cc, matrix.cc, args.cc and so on, can be read in the same way.

