fastText Source Code Analysis
Posted by miner007
Purpose: to record a walkthrough of the code that combines several references with my own understanding.
References:
https://heleifz.github.io/14732610572844.html
http://www.cnblogs.com/peghoty/p/3857839.html
1: Overall module relationship diagram
The core modules are fasttext.cc and model.cc, but the auxiliary modules matter just as much: they are the nuts and bolts of the codebase, and they determine which data structures organize the training data. These structures are worth studying, since the way training data is stored here is a fairly common pattern; later on it is instructive to compare how several codebases structure their training data.
Part: dissecting the nuts-and-bolts code
2: The dictionary module
```cpp
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "dictionary.h"

#include <assert.h>

#include <iostream>
#include <algorithm>
#include <iterator>
#include <unordered_map>

namespace fasttext {

const std::string Dictionary::EOS = "</s>";
const std::string Dictionary::BOW = "<";
const std::string Dictionary::EOW = ">";

Dictionary::Dictionary(std::shared_ptr<Args> args) {
  args_ = args;
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  ntokens_ = 0;
  word2int_.resize(MAX_VOCAB_SIZE); // index over all words; hash values lie in [0, MAX_VOCAB_SIZE-1]
  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
    word2int_[i] = -1;
  }
}

// Hash the string; on collision, probe linearly until the matching slot is found.
int32_t Dictionary::find(const std::string& w) const {
  int32_t h = hash(w) % MAX_VOCAB_SIZE;
  while (word2int_[h] != -1 && words_[word2int_[h]].word != w) {
    h = (h + 1) % MAX_VOCAB_SIZE;
  }
  return h;
}

// Add a word to words_; the word may also be a label.
void Dictionary::add(const std::string& w) {
  int32_t h = find(w);
  ntokens_++; // tokens processed so far
  if (word2int_[h] == -1) {
    entry e;
    e.word = w;
    e.count = 1;
    e.type = (w.find(args_->label) == 0) ? entry_type::label : entry_type::word; // starts with the label prefix -> label
    words_.push_back(e);
    word2int_[h] = size_++;
  } else {
    words_[word2int_[h]].count++;
  }
}

// Number of distinct words (labels excluded).
int32_t Dictionary::nwords() const {
  return nwords_;
}

// Number of distinct labels.
int32_t Dictionary::nlabels() const {
  return nlabels_;
}

// Number of tokens processed so far (with repetitions).
int64_t Dictionary::ntokens() const {
  return ntokens_;
}

// Get the n-grams of a word by its id.
const std::vector<int32_t>& Dictionary::getNgrams(int32_t i) const {
  assert(i >= 0);
  assert(i < nwords_);
  return words_[i].subwords;
}

// Get the n-grams of a word by its string form.
const std::vector<int32_t> Dictionary::getNgrams(const std::string& word) const {
  int32_t i = getId(word);
  if (i >= 0) {
    return getNgrams(i);
  }
  // The word is out of vocabulary: compute its n-grams on the fly,
  // so an unknown word can still be represented through its subwords.
  std::vector<int32_t> ngrams;
  computeNgrams(BOW + word + EOW, ngrams);
  return ngrams;
}

// Decide whether to discard a word: overly frequent words carry little information.
bool Dictionary::discard(int32_t id, real rand) const {
  assert(id >= 0);
  assert(id < nwords_);
  if (args_->model == model_name::sup) return false; // the supervised model does not subsample
  return rand > pdiscard_[id];
}

// Get the id of a word.
int32_t Dictionary::getId(const std::string& w) const {
  int32_t h = find(w);
  return word2int_[h];
}

// Type of the entry (word or label).
entry_type Dictionary::getType(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].type;
}

// Get the word string by id.
std::string Dictionary::getWord(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].word;
}

// Hash rule (FNV-1a).
uint32_t Dictionary::hash(const std::string& str) const {
  uint32_t h = 2166136261;
  for (size_t i = 0; i < str.size(); i++) {
    h = h ^ uint32_t(str[i]);
    h = h * 16777619;
  }
  return h;
}

// Compute the character n-grams of a word.
void Dictionary::computeNgrams(const std::string& word,
                               std::vector<int32_t>& ngrams) const {
  for (size_t i = 0; i < word.size(); i++) {
    std::string ngram;
    if ((word[i] & 0xC0) == 0x80) continue; // skip UTF-8 continuation bytes
    for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) { // grow the n-gram one character at a time
      ngram.push_back(word[j++]);
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]); // absorb the remaining bytes of a multi-byte character
      }
      if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
        int32_t h = hash(ngram) % args_->bucket; // bucket index for this n-gram
        ngrams.push_back(nwords_ + h);
      }
    }
  }
}

// Initialize the n-grams of every word in the vocabulary.
void Dictionary::initNgrams() {
  for (size_t i = 0; i < size_; i++) {
    std::string word = BOW + words_[i].word + EOW;
    words_[i].subwords.push_back(i);
    computeNgrams(word, words_[i].subwords);
  }
}

// Read one token from the stream.
bool Dictionary::readWord(std::istream& in, std::string& word) const
{
  char c;
  std::streambuf& sb = *in.rdbuf();
  word.clear();
  while ((c = sb.sbumpc()) != EOF) {
    if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
        c == '\f' || c == '\0') {
      if (word.empty()) {
        if (c == '\n') { // empty line: emit an EOS token
          word += EOS;
          return true;
        }
        continue;
      } else {
        if (c == '\n')
          sb.sungetc(); // push the newline back so it becomes an EOS next time
        return true;
      }
    }
    word.push_back(c);
  }
  // trigger eofbit
  in.get();
  return !word.empty();
}

// Read the whole file: build the vocabulary, then initialize the discard table and the n-grams.
void Dictionary::readFromFile(std::istream& in) {
  std::string word;
  int64_t minThreshold = 1; // pruning threshold
  while (readWord(in, word)) {
    add(word);
    if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
      std::cout << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
    }
    if (size_ > 0.75 * MAX_VOCAB_SIZE) { // keep the vocabulary below 75% of capacity
      minThreshold++;
      threshold(minThreshold, minThreshold); // prune words rarer than minThreshold (and sort)
    }
  }
  threshold(args_->minCount, args_->minCountLabel); // final sort plus the user-specified pruning

  initTableDiscard();
  initNgrams();
  if (args_->verbose > 0) {
    std::cout << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
    std::cout << "Number of words: " << nwords_ << std::endl;
    std::cout << "Number of labels: " << nlabels_ << std::endl;
  }
  if (size_ == 0) {
    std::cerr << "Empty vocabulary. Try a smaller -minCount value." << std::endl;
    exit(EXIT_FAILURE);
  }
}

// Prune the vocabulary and sort it.
void Dictionary::threshold(int64_t t, int64_t tl) {
  sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
    if (e1.type != e2.type) return e1.type < e2.type; // words first, labels after
    return e1.count > e2.count; // within a type, descending count
  });
  words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
    return (e.type == entry_type::word && e.count < t) ||
           (e.type == entry_type::label && e.count < tl);
  }), words_.end()); // drop entries below the thresholds
  words_.shrink_to_fit();
  // Rebuild the dictionary bookkeeping.
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
    word2int_[i] = -1; // reset the hash table
  }
  for (auto it = words_.begin(); it != words_.end(); ++it) {
    int32_t h = find(it->word); // recompute the hash slot
    word2int_[h] = size_++;
    if (it->type == entry_type::word) nwords_++;
    if (it->type == entry_type::label) nlabels_++;
  }
}

// Initialize the discard table; args_->t is the sampling threshold
// (smaller values subsample frequent words more aggressively).
void Dictionary::initTableDiscard() {
  pdiscard_.resize(size_);
  for (size_t i = 0; i < size_; i++) {
    real f = real(words_[i].count) / real(ntokens_); // empirical frequency of the word
    pdiscard_[i] = sqrt(args_->t / f) + args_->t / f; // NOTE: seems to differ from the paper (see discussion below)
  }
}

// Return the counts of all entries of the given type.
std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
  std::vector<int64_t> counts;
  for (auto& w : words_) {
    if (w.type == type) counts.push_back(w.count);
  }
  return counts;
}

// Append word n-gram ids (hashed into buckets) to a tokenized line.
void Dictionary::addNgrams(std::vector<int32_t>& line, int32_t n) const {
  int32_t line_size = line.size();
  for (int32_t i = 0; i < line_size; i++) {
    uint64_t h = line[i];
    for (int32_t j = i + 1; j < line_size && j < i + n; j++) {
      h = h * 116049371 + line[j];
      line.push_back(nwords_ + (h % args_->bucket));
    }
  }
}

// Read one line: fill words with word ids and labels with label ids.
int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::vector<int32_t>& labels,
                            std::minstd_rand& rng) const {
  std::uniform_real_distribution<> uniform(0, 1); // uniform draw in [0, 1]
  std::string token;
  int32_t ntokens = 0;
  words.clear();
  labels.clear();
  if (in.eof()) {
    in.clear();
    in.seekg(std::streampos(0));
  }
  while (readWord(in, token)) {
    if (token == EOS) break; // end of the line
    int32_t wid = getId(token);
    if (wid < 0) continue; // unknown word: skip it
    entry_type type = getType(wid);
    ntokens++; // tokens seen on this line
    if (type == entry_type::word && !discard(wid, uniform(rng))) { // subsample words
      words.push_back(wid); // word ids are always below nwords_
    }
    if (type == entry_type::label) { // labels are always kept; their ids sit above nwords_
      labels.push_back(wid - nwords_); // so nwords_ must be added back to recover the entry
    }
    if (words.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break; // word-vector models cap the line length
  }
  return ntokens;
}

// Get a label string by its label id.
std::string Dictionary::getLabel(int32_t lid) const {
  assert(lid >= 0);
  assert(lid < nlabels_);
  return words_[lid + nwords_].word;
}

// Serialize the dictionary.
void Dictionary::save(std::ostream& out) const {
  out.write((char*) &size_, sizeof(int32_t));
  out.write((char*) &nwords_, sizeof(int32_t));
  out.write((char*) &nlabels_, sizeof(int32_t));
  out.write((char*) &ntokens_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    entry e = words_[i];
    out.write(e.word.data(), e.word.size() * sizeof(char));
    out.put(0); // string terminator
    out.write((char*) &(e.count), sizeof(int64_t));
    out.write((char*) &(e.type), sizeof(entry_type));
  }
}

// Deserialize the dictionary.
void Dictionary::load(std::istream& in) {
  words_.clear();
  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
    word2int_[i] = -1;
  }
  in.read((char*) &size_, sizeof(int32_t));
  in.read((char*) &nwords_, sizeof(int32_t));
  in.read((char*) &nlabels_, sizeof(int32_t));
  in.read((char*) &ntokens_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    char c;
    entry e;
    while ((c = in.get()) != 0) {
      e.word.push_back(c);
    }
    in.read((char*) &e.count, sizeof(int64_t));
    in.read((char*) &e.type, sizeof(entry_type));
    words_.push_back(e);
    word2int_[find(e.word)] = i; // rebuild the index
  }
  initTableDiscard(); // rebuild the discard table
  initNgrams();       // rebuild the n-grams
}

} // namespace fasttext
```
A few points that deserve explanation:
1: How a string is mapped to its dictionary entry, i.e., how the index is built. The key function is find: it hashes the string (probing linearly on collisions) and uses two vectors to complete the chain StrToHash (the find function) → HashToIndex (the word2int_ array) → IndexToStruct (the words_ array). A self-contained sketch of this chain follows.
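To make the chain concrete, here is a minimal sketch of the same scheme: FNV-1a hashing plus linear probing. The table size of 16 is purely for illustration (the real code uses MAX_VOCAB_SIZE), and Entry is a trimmed-down stand-in for the dictionary's entry struct.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Entry { std::string word; int64_t count; };

const int32_t kVocabSize = 16; // toy stand-in for MAX_VOCAB_SIZE

// Same FNV-1a hash as Dictionary::hash.
uint32_t fnv1a(const std::string& s) {
  uint32_t h = 2166136261;
  for (char c : s) { h ^= uint32_t(c); h *= 16777619; }
  return h;
}

// Same linear probing as Dictionary::find: stop at an empty slot,
// or at a slot whose entry already holds this exact word.
int32_t find(const std::vector<int32_t>& word2int,
             const std::vector<Entry>& words, const std::string& w) {
  int32_t h = fnv1a(w) % kVocabSize;
  while (word2int[h] != -1 && words[word2int[h]].word != w) {
    h = (h + 1) % kVocabSize;
  }
  return h;
}

int main() {
  std::vector<int32_t> word2int(kVocabSize, -1); // HashToIndex
  std::vector<Entry> words;                      // IndexToStruct
  for (const std::string& w : {"the", "cat", "the"}) {
    int32_t h = find(word2int, words, w);
    if (word2int[h] == -1) {      // first occurrence: create an entry
      Entry e; e.word = w; e.count = 1;
      words.push_back(e);
      word2int[h] = static_cast<int32_t>(words.size()) - 1;
    } else {                      // repeated word: bump its count
      words[word2int[h]].count++;
    }
  }
  int32_t h = find(word2int, words, "the");
  std::cout << "\"the\" -> slot " << h << " -> index " << word2int[h]
            << " -> count " << words[word2int[h]].count << std::endl; // count is 2
}
```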
2: Several tables are initialized up front to speed up the later passes:
1) The n-gram table: every vocabulary word gets a list of n-gram ids (its subwords). For example, take the word "我想你": computeNgrams produces the ids of its character n-grams. With a minimum n-gram length of 2 and a maximum of 3, the subwords are "<我", "<我想", "我想", "我想你", "想你", "想你>", "你>"; the "<" and ">" appear because begin-of-word and end-of-word markers are added automatically. Note the test (word[j] & 0xC0) == 0x80 in the implementation: it recognizes UTF-8 continuation bytes, so a multi-byte character such as a Chinese character is always taken whole as one "character". A sketch of this extraction appears right after this list.
2) The initTableDiscard table: each word gets a value derived from its frequency, and during sampling the word is discarded whenever a uniform random draw exceeds that value. The motivation is that extremely frequent words, such as "的" or "是", are mostly function words that carry little information, so they should be subsampled. The formula differs slightly from the subsampling rule in the word2vec paper: writing p = sqrt(t/f), the paper keeps a word with probability p, while this code keeps it with probability p + p² (the same expression used in Google's original word2vec C implementation). Since p + p² ≥ p, high-frequency words are kept slightly more often here, i.e., the subsampling is a bit milder than the paper's. A numeric comparison also follows this list.
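First, a self-contained sketch of the character n-gram extraction from item 1). It reuses the same UTF-8 continuation-byte test as computeNgrams, but returns the n-gram strings themselves instead of hashed bucket ids, so the output is easy to inspect.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Extract character n-grams of length [minn, maxn], treating each
// UTF-8 multi-byte sequence as a single character.
std::vector<std::string> charNgrams(const std::string& word,
                                    size_t minn, size_t maxn) {
  std::vector<std::string> out;
  for (size_t i = 0; i < word.size(); i++) {
    if ((word[i] & 0xC0) == 0x80) continue; // not the first byte of a character
    std::string ngram;
    for (size_t j = i, n = 1; j < word.size() && n <= maxn; n++) {
      ngram.push_back(word[j++]);
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]); // pull in the rest of this character
      }
      if (n >= minn && !(n == 1 && (i == 0 || j == word.size()))) {
        out.push_back(ngram);
      }
    }
  }
  return out;
}

int main() {
  // "<" and ">" are the BOW/EOW markers the dictionary adds around a word.
  for (const auto& g : charNgrams("<我想你>", 2, 3)) {
    std::cout << g << std::endl;
  }
  // prints: <我 <我想 我想 我想你 想你 想你> 你>
}
```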
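Second, a small numeric comparison for item 2): the paper's keep probability sqrt(t/f) against the implementation's sqrt(t/f) + t/f, at a few word frequencies. Here t = 1e-4, fastText's default sampling threshold (args_->t); values above 1 simply mean the word is always kept, since the random draw lies in [0, 1].

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double t = 1e-4; // default sampling threshold
  for (double f : {1e-2, 1e-3, 1e-4, 1e-5}) {
    double paper = std::sqrt(t / f);         // keep probability in the paper
    double impl  = std::sqrt(t / f) + t / f; // keep probability in the code
    std::printf("f=%.0e  paper=%.4f  impl=%.4f\n", f, paper, impl);
  }
  // For f = 1e-2 the paper keeps ~10% of occurrences, the code ~11%:
  // the implementation's subsampling is consistently slightly milder.
}
```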
3: External code drives this module along two main paths: readFromFile, which loads the words and builds the dictionary, and getLine, which reads one sentence and returns its word and label ids. A minimal usage sketch follows.
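The sketch below exercises both entry points. It assumes the fastText headers are on the include path; "train.txt" is just a placeholder file name.

```cpp
#include <fstream>
#include <memory>
#include <random>
#include <vector>

#include "args.h"
#include "dictionary.h"

using namespace fasttext;

int main() {
  auto args = std::make_shared<Args>();
  Dictionary dict(args);

  // 1) Build the vocabulary with one pass over the corpus.
  std::ifstream ifs("train.txt");
  dict.readFromFile(ifs);

  // 2) Re-read the corpus line by line, collecting word and label ids.
  std::ifstream in("train.txt");
  std::minstd_rand rng(42);
  std::vector<int32_t> words, labels;
  while (in.peek() != EOF) {
    dict.getLine(in, words, labels, rng);
    // words[k] are in-vocabulary word ids (always < dict.nwords());
    // labels[k] are label ids, already offset by nwords() inside getLine.
  }
  return 0;
}
```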
The similar modules vector.cc, matrix.cc, args.cc, and so on are analyzed as follows:
```cpp
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */
```