如何使用BM25算法检索出最相关的序列
Posted CSU迦叶
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何使用BM25算法检索出最相关的序列相关的知识,希望对你有一定的参考价值。
背景
起因
博主正在进行的科研应用到了in-context learning这个范式,与传统的学习范式不同,情境中学习并不是真的学习,即不改变模型的参数,称为in-context inference 或许更为贴切。Icl不要求存在传统的训练集,但是依然需要少数几个有标签的样例(demostration)来告诉模型要做什么任务。那么模型推理结果的好坏很大程度上就取决于这几个样例给的好不好。
那么如何得到样例,一种思路是这样的,假定已经得到了一个demo pool(也许是传统学习范式中的训练集变化来的),对于任意一个query(和demo的区别是它没有标签),什么样的demos是好的呢?我们可以直观地联想到,和query越像,这个demo也就越适合作为icl中的样例。
BM25
BM25算法于2009年在论文The Probabilistic Relevance Framework: BM25 and Beyond中提出,可能是提出早的缘故,已经有python第三方库对它做了很好的实现,2023年我们只需会调用。
from rank_bm25 import BM25Okapi
使用过程
总体思路
两个步骤:1)将准备好的demo pool传给BM25okapi,这个过程会得到一个将哈希映射到序列的缓存字典 2)使用实例化后的BM25okapi,传入query,得到最相似的n个demo的哈希,再使用上一步得到的字典映射回序列。
代码
对于步骤一,我参照仓库(https://github.com/prompt-learning/cedar),单独定义了一个工具类Util,在当中定义了静态方法load_bm_25(),如下
class Util:
@staticmethod
def load_bm_25(bm_25_cache_dict, test_methods, demoPool:List[Oracle_datapoint_with_demo_length]):
start_time = time.time()
how_many_md5hash_conflicts = 0
for dp in demoPool:
tokenized_test_method = dp.datapoint.test_method.split(" ") # 匹配datapoint哪些元素之间的相似度是这里决定的
md5hash = hashlib.md5(" ".join(tokenized_test_method).encode('utf-8')).hexdigest()
if md5hash in bm_25_cache_dict:
how_many_md5hash_conflicts += 1
else:
bm_25_cache_dict[md5hash] = dp
test_methods.append(dp.datapoint.test_method)
print("how_many_md5hash_conflicts: ", how_many_md5hash_conflicts)
bm25 = BM25Okapi(test_methods)
end_time = time.time()
print("load_bm_25: ", end_time - start_time)
print("The size of the bm25 cache is bytes".format(sys.getsizeof(bm_25_cache_dict)))
print(f"total entries: len(bm_25_cache_dict.keys())")
return bm25
load_bm_25()有三个参数,分别是bm_25_cache_dict,test_methods,demoPool.下面分别说一下他们的作用:
首先前两个参数,传进去的分别是空字典和空列表,但是从函数出来以后,它们都被装载了内容,主要为了后期使用时候的校验。至于第三个参数demoPool,顾名思义,就是要把所有的候选demo给放在这个列表当中,类型是可以自定义的。但是注意看这一句
bm25 = BM25Okapi(test_methods)
真正用来初始化BM25Okapi即构成池子的,其实是demoPool的一部分,test_methods。当然也可以不这么写,直接传入有内容的test_methods,这样第三个参数也就省去了。但是这样写还是有原因的,虽然对于每一个demo,我们只用它的一部分来比较和query的相关性,但毕竟这里的demo可以直接等价于行文时所说的例子。
步骤二的代码如下:
def bm25_retrived_demos(query:Oracle_datapoint,demoPool:List[Oracle_datapoint_with_demo_length])->List[Oracle_datapoint]:
bm_25_cache_dict =
test_methods = []
bm25 = Util.load_bm_25(bm_25_cache_dict, test_methods, demoPool)
tokenized_query = query.test_method.split(" ")
results_top_n = bm25.get_top_n(tokenized_query, test_methods, n=2)
candidate_demonstrations:List[Oracle_datapoint] = []
length_so_far = 0
for r in results_top_n:
md5hash_of_query = hashlib.md5(r.encode('utf-8')).hexdigest()
if md5hash_of_query in bm_25_cache_dict:
dp = bm_25_cache_dict[md5hash_of_query]
candidate_demo_token_count = dp.token_count
if (length_so_far + candidate_demo_token_count) <= 7000: # 7000暂时取代本应计算的max_demo_length
candidate_demonstrations.append(dp.datapoint)
length_so_far += candidate_demo_token_count
else:
break
else:
raise Exception("why key missing in the dict?")
print("number of candidate demonstrations: ", len(candidate_demonstrations))
return candidate_demonstrations
在第四行,调用了步骤一中定义的方法。此外最关键的一行代码是第七行
results_top_n = bm25.get_top_n(tokenized_query, test_methods, n=2)
前两个参数前面已经介绍过,最后一个参数用来决定返回前多少个最相似的序列。
来看一下最终的调用!
DATA_PATH = "sample/"
dds = Dataset(
DATA_PATH,
"1_demo_pool_focal_method.txt" ,
"2_demo_pool_test_method.txt" ,
"3_demo_pool_focal_name.txt" ,
"4_demo_pool_test_name.txt" ,
"5_demo_pool_oracle.txt" ,
)
demoPool : List[Oracle_datapoint] = []
demoPool = get_demo_pool(dds.parse())
candidate_demos = bm25_retrived_demos(query,demoPool)
以上是关于如何使用BM25算法检索出最相关的序列的主要内容,如果未能解决你的问题,请参考以下文章
Elasticsearch实用BM25 -第2部分:BM25算法及其变量
ElasticSearch实战-TF/IDF/BM25分值计算(文本搜索排序分值计算,全文检索算法,文本相似度算法)