明月机器学习系列030:特殊二分图的最优匹配算法
Posted 红楼明月
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了明月机器学习系列030:特殊二分图的最优匹配算法相关的知识,希望对你有一定的参考价值。
1. 缘起
最近开发文档识别与比对,经常遇到的一个问题就是谁跟谁应该配对在一起,例如:
两个页面上的文本行,哪行跟哪行应该是对应的?
两份文档中都有若干个表格,哪个表格跟哪个表格应该是对应的?
两个表格都会包含若干的单元格,这些单元格哪个跟哪个是对应的?
开始时,想得比较简单,因为看上去问题也不复杂嘛。
2. 算法的第一个版本
把问题抽象一下,其实不管是单元格,表格,还是文本行都可以看成是一个个的元素,于是我们的问题就成了在两个有序的序列中寻找一个最优的匹配,每个元素最多能跟一个元素进行匹配(可以没有匹配),如下图:
上图显示的就是左右两个集合中的一种匹配,但是这种匹配的形式跟我们的要求是不符合的。我们的场景下,我们匹配的左右两边并不是无序集合,而是两个有序的序列。例如从上图来理解,如果1和6匹配上之后,2则只能和6后面的7或者8进行匹配,所以上图的4这个元素开倒车不符合规则(如果把4和5这两个元素之间的边去掉的话,则是满足条件)。
定义:边就是两个之间的连线。
2.1 算法的目标
我们既然要找到最优的匹配,但是怎么才算是最优呢?这就是要求我们先定义一个数值指标,以此来衡量优劣。这也比较简单,对每条连线的权重求和,以此作为衡量指标。连线的权重则采用两个元素之间的相似性,如果这两个元素是两行文本,我们可以直接使用编辑距离计算相似性(至于那种距离更加合适,就得看具体的场景了。余弦距离,杰卡德距离,甚至语义距离等都可以,只要合适)。
简单说就是,对元素之间的相似性得分进行求和。
2.2 算法思路
有了目标,那看起来就比较简单了,直接从左边元素随机取一个子集,然后再右边元素也随机取一个相同元素个数的子集,再按顺序对应上,就能计算一个得分指标。
这个思路简单,实现也简单,直接上代码:
def less_match(self):
"""当数据比较小时,可以使用穷举匹配"""
max_score = 0
items = np.array([])
max_num = min(len(self.seq1), len(self.seq2))
for num in range(max_num, 0, -1):
tmp_score, tmp_items = self.match_num(num)
if tmp_score >= max_score:
# 两个空元素配对在一起,去掉之后得分不会改变,但是空元素不应该配对在一起
items = tmp_items
max_score = tmp_score
else:
# 如果不能产生更好的值,则退出
break
return items
def match_num(self, num):
"""从两个序列中分别提取num个元素进行匹配"""
# print('match num: ', num)
max_score = 0
comb_match = None
# 提取两个序列的下标子集合
comb1 = combinations(range(len(self.seq1)), num)
comb2 = list(combinations(range(len(self.seq2)), num))
for comb_i in comb1:
for comb_j in comb2:
tmp_score = self.cal_comb_score(comb_i, comb_j)
if tmp_score is None:
continue
if tmp_score >= max_score:
max_score = tmp_score
comb_match = (comb_i, comb_j)
# 生成配对items
if comb_match is None:
return 0, []
items = np.array((comb_match[0], comb_match[1])).T
return max_score, items
def cal_comb_score(self, comb_i, comb_j):
"""计算集合得分"""
where = (np.array(comb_i, dtype=int), np.array(comb_j, dtype=int))
scores = self.scores[where]
if self.min_score is not None and np.min(scores) < self.min_score:
return None
return np.sum(scores)
暴力出奇迹(指时间),很快就完成了第一个版本。这个版本其实已经做了部分的剪枝,已经部分计算已经提前,例如元素之间的距离就是预先计算好放到scores中。
但是显然这个版本存在巨大的性能问题。
3. 优化版本
上面的算法在数据量小的时候,还没有问题,但是数据量稍大一点,因为取集合的方式是指数级的,想不废都难。
3.1 剪枝优化
剪枝1. 在我们的场景中,相似度得分大于0,但是其值却很小的边通常是没有意义的,这样我们就可以通过阈值参数直接过滤掉这部分的边。
剪枝2. 仔细分析上面的暴力算法,就会发现其实很多计算是多余的,因为在我们的场景中,一个元素通常只会和另一个序列中附近的元素产生联系,和位置相差比较远的元素产生联系的可能性是很小的,但是在计算编辑距离时,却有可能联系在一起。例如左右两边的序列都有50个元素,左边的第一个元素值恰好和右边元素的最后一个元素的值完全相同,这时他们这两个元素的相似性得分最大,但是这基本是不可能的。于是我们可以考虑将位置因素整合到权重得分上。
# 剪枝:其值却很小的边通常是没有意义的
# self.min_score: 这个是算法的参数,可以根据不同的场景选择不同的阈值
where_i, where_j = np.where(self.scores > self.min_score)
len_j = len(where_j)
# 优化得分: 将位置影响整合到边的权重上
for j, val_j in enumerate(where_j):
# 正常来说,where_j是按顺序排序的
# 如果前面有比当前值大,或者后面有比当前值小,这两种情况都是不常见的,可以减少其权重
err_num = np.count_nonzero(where_j[:j] > val_j)
err_num += np.count_nonzero(where_j[j:] < val_j)
self.scores[where_i[j], val_j] *= (len_j-err_num)/(len_j)
这段代码实现了前面两个剪枝的方式。这里融合位置的方式设计上比较特别,具体可以看代码注释。
剪枝3. 基于第一点的分析,我们还可以在预先计算相似性得分的,只计算相邻位置的元素之间的边的相似性得分,其他的全部置为0。
# 计算得分
len1, len2 = len(seq1), len(seq2)
# 计算窗口的开始和结束位置
start, end = -window, window
if len2 >= len1:
end += len2 - len1
else:
start += len2 - len1
scores = np.zeros((len1, len2))
for i, s1 in enumerate(seq1):
# 一个元素通常只会和另一个序列中相邻的元素产生联系
w_start, w_end = max(0, i+start), min(len2, i+end)
scores[i][w_start:w_end] = [score_func(s1, s2) for s2 in seq2[w_start:w_end]]
3.2 计算优化
元素与元素之间的边的权重已经计算出来了,我们不再使用遍历集合这种暴力的方式,而是先找连通子图,然后在每个连通图的内部删掉一些多余的边,使得每个元素最多只和一个元素联通,并且保证每个联通子图删掉多余的边之后,相似度得分是最高的。简单说就是保证每个联通子图的最优来保证全局最优(当然这不一定成立,但是概率很小,而且即使不是全局最优,也和全局最优相差不多了,所以可以忽略)。连通图计算可以直接使用networkx包中的connected_components函数。
代码行数比较多,就不凑字数了,具体看:https://github.com/ibbd-dev/python-ibbd-algo/blob/master/ibbd_algo/sequence.py
经过这个优化,在我们的场景下,性能基本没什么问题了。
4. 后续思考
后来查资料得知,图论里专门有一种叫二分图,还有相关的算法,不过我们的场景却比较特别,算是一种特殊的二分图吧。研究一下现有的二分图,应该还是有改进空间的。
附录:
源码:https://github.com/ibbd-dev/python-ibbd-algo/blob/master/ibbd_algo/sequence.py
20201230:这个文章上个月就开始写了,只是一直在草稿了,今晚算是补充完整了,自己也梳理了一遍。
以上是关于明月机器学习系列030:特殊二分图的最优匹配算法的主要内容,如果未能解决你的问题,请参考以下文章