基于SimCSE和Faiss的文本向量检索实践
Posted 行走的人偶
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于SimCSE和Faiss的文本向量检索实践相关的知识,希望对你有一定的参考价值。
目录
传统的文本检索一般是建立倒排索引,对搜索词的召回结果进行打分排序返回最终结果,但是在海量的数据面前,召回结果页面临着一些挑战。于是就有了基于语义的搜索,即将文本向量化,默认向量包含了文本的语义信息,匹配最近的向量返回结果。
文本的向量表示
文本的向量表示有很多种方式,例如:one-hot,tf-idf,word2vec,或者基于深度学习的sentence-bert,simbert等等,这里尝试采用近年来小有名气的SimCSE,据说在各种任务上面都达到了SOTA的水平,而且支持无监督的训练,这一点就足够吸引我们来试验一下效果了。
1、SimCSE
SimCSE: Simple Contrastive Learning of Sentence Embeddings,可以用来做句向量的表征训练,对于模型本身网上已经有很多介绍了,这里不做太多阐述,只介绍算法几个核心的点以及训练模型的注意事项。
2、支持无监督训
很多句向量模型都是有监督的,构建正样本和负样本的数据集,这种方式大家比较容易想到,也比较好理解,结果的好坏也和标注数据息息相关。而无监督的模型就比较少见,SimCSE无监督训练构思比较巧妙,也很简单,给人一种大道至简的感觉。
SimCSE的入手点是一个batch的数据,没条样本数据自己和自己构成正样本,自己和其他数据构成负样本,负样本这里是没有问题,但是正样本自己和自己就缺少泛化能力,为了使正样本之间有所差异并且保持相似性,可以使用数据增强的方法,例如增加噪音或者GAN生成等方式,当然这就变成了另外的话题了。作者利用不同的dropout来产生正样本,同一个batch里面每个数据的dropout不一致。
在构建无监督batch的数据时,同一条数据重复一次,即一个batch内最后的数据形式为:[a1,a1,a2,a2.....an,an]。为了方便理解,我们把最终的混淆矩阵画出来(带有+号的是经过drop生成的正样本):
a1 | a1+ | a2 | a2+ | ... | an | an+ | |
a1 | 0 | 1 | 0 | 0 | ... | 0 | 0 |
a1+ | 1 | 0 | 0 | 0 | ... | 0 | 0 |
a2 | 0 | 0 | 0 | 1 | ... | 0 | 0 |
a2+ | 0 | 0 | 1 | 0 | ... | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... |
an | 0 | 0 | 0 | 0 | ... | 0 | 1 |
an+ | 0 | 0 | 0 | 0 | ... | 1 | 0 |
归一化之后,使用相乘即可计算出来向量之间的cosine距离,接着可以计算出loss。
def simcse_loss(y_true, y_pred):
"""用于SimCSE训练的loss
"""
# 构造标签
idxs = K.arange(0, K.shape(y_pred)[0])
idxs_1 = idxs[None, :]
idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
y_true = K.equal(idxs_1, idxs_2)
y_true = K.cast(y_true, K.floatx())
# 计算相似度
y_pred = K.l2_normalize(y_pred, axis=1)
similarities = K.dot(y_pred, K.transpose(y_pred))
similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
similarities = similarities * 20
loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)
return K.mean(loss)
3、训练注意事项
学习率设置为1e-5,dropout设置为0.3。
随机选取了1W条博客标题,训练一个epoch,即可得到很好的效果。多增加训练数据,或者训练epoch效果反而会下降。无监督训练确实让人很迷惑啊。
个人感觉SimCSE其实也没有真正学习到句子的语义,一个句子增加一个不字,变成语义相反的句子,计算其相似度还是挺高的,跨越语义鸿沟任道而重远,不知道gpt3/gpt4是不是距离真正的语义越来越近了。
向量检索
向量检索可以用ES,8.0版本听说优化了检索的算法,由于公司的ES还没有升级,也没能验证下大数据集检索的速度怎么样,若是能得到大幅度优化,ES还是首先的,毕竟其全文检索能力还是很强;也可以用Faiss,海量数据表现出色,但是其只是一个包,需要进行二次开发;更可以用milvus,是一个向量数据库,集成了Faiss,支持属性的索引联合搜索,但是也还存在一些坑,待后续版本完善。这里决定采用Faiss,单一字段的检索,Faiss还是能胜任的。
使用Faiss比较简单,但是要选择一个适合自己业务场景的索引类型很重要,现实场景中我们还要考虑数据量,召回率,内存大小,响应时间等因素,可以参考官方文档。
1、精准查找flat
基于穷举,召回率最高,但是速度慢,适合数据量比较小的场景。
2、HNSWx
基于图检索的方法,召回率高,检索速度快,但是构建索引慢,占用内存极大。
3、IVFx
基于倒排索引,IVFx中的x是k-means聚类中心的个数,使用倒排索引的思想减少检索时间。
4、PQx
基于乘积量化,分段检索,然后取交集,检索速度快,内存占用较小,但是召回率有所下降。
5、LSH
基于局部敏感哈希,占用内存小,检索速度快,但是召回率低。
考虑到生产环境的内存占用,检索的数据量在5000W+,这里使用网上比较推荐的IVFxPQy的索引方式,比较中规中矩的一种方法。
对博客标题进行向量检索
数据向量化
利用上面训练好的SimCSE模型,使用批推理的方式,将所有数据向量化,主要是为了利用批推理加快处理速度。
def data_to_vec(self):
"""数据向量化"""
data_list = []
batch_data = []
count = 0
with open(self.blog_data_path) as file:
for line in file:
terms = line.strip().split("\\t")
article_id, title = terms[0], terms[1]
count += 1
if count % 1000 == 0:
print(f"processed: count")
if len(batch_data) >= self.model_predict.model_config.batch_size:
vecs = self.model_predict.predict_batch(batch_data)
batch_data.clear()
data_list.extend(vecs)
batch_data.append(title)
if len(batch_data):
vecs = self.model_predict.predict_batch(batch_data)
data_list.extend(vecs)
batch_data.clear()
xb = np.array(data_list, dtype=np.float32)
np.save(self.data_vec_path, xb)
print("save vector.")
构建索引
IVFx中x取100,PQy中y取16。
def create_index(self):
"""创建索引"""
param = 'IVF100,PQ16'
measure = faiss.METRIC_L2
index = faiss.index_factory(self.model_predict.model_config.output_units, param, measure)
xb = np.load(self.data_vec_path + ".npy")
start_time = time.time()
print("create index...")
index.train(xb)
index.add(xb)
faiss.write_index(index, self.index_path)
end_time = time.time()
print(f"index created.cost:end_time-start_time")
文本检索
首先将句子转换成向量,再进行检索,取topk,然后计算余弦相似度,进行重新排序,输出结果。
def search(self, query, topk=5):
""""搜索"""
vec = self.model_predict.predict(query.lower())
d, idx_list = self.index.search(np.array([vec]), topk)
result_list = []
for idx in range(len(idx_list[0])):
data = self.data_list[idx_list[0][idx]]
new_distance = self._cos_distinct(data, query)
result_list.append((data, new_distance))
sorted_result = sorted(result_list, key=lambda term: term[1], reverse=True)
return sorted_result
测试检索
输入检索句子(就拿本文的标题来检索一下吧):
基于SimCSE和Faiss的文本向量检索实践
输出:
基于Lucene的全文检索实践
基于mongodb的地理检索实现
基于mongodb的地理检索实现
PostgreSQL的FTI与中文全文索引的实践
docker + laravel项目使用elasticsearch进行全文检索功能
再看一下CSDN官网的检索结果:
官网的结果基于词的全文检索,本文给的结果有那么一点儿语义的味道,但是距离真语义还是有一定差距,现在这个阶段的话倒是可以作为一个召回源对全文检索进行补充。
下课咯。
以上是关于基于SimCSE和Faiss的文本向量检索实践的主要内容,如果未能解决你的问题,请参考以下文章
效果提升28个点 基于领域预训练和对比学习SimCSE的语义检索
EMNLP 2021SimCSE:句子嵌入的简单对比学习 && CVPR 2021理解对比学习损失函数及温度系数