合并共享共同元素的列表

Posted

技术标签:

【中文标题】合并共享共同元素的列表【英文标题】:Merge lists that share common elements 【发布时间】:2011-06-18 01:48:54 【问题描述】:

我的输入是一个列表列表。其中一些具有共同的元素,例如。

L = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]

我需要合并所有共享一个共同元素的列表,并重复此过程,只要没有更多具有相同项目的列表。我考虑过使用布尔运算和while循环,但没有想出一个好的解决方案。

最终结果应该是:

L = [['a','b','c','d','e','f','g','o','p'],['k']] 

【问题讨论】:

合并是什么意思?联盟?您能否展示您对示例数据的预期结果? Simplified solution 用于长度为 2 个子列表(以及更多) 【参考方案1】:

您可以将您的列表视为图形的符号,即['a','b','c'] 是一个具有3 个相互连接的节点的图形。您要解决的问题是找到connected components in this graph。

您可以为此使用NetworkX,其优点是几乎可以保证正确:

l = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]

import networkx 
from networkx.algorithms.components.connected import connected_components


def to_graph(l):
    G = networkx.Graph()
    for part in l:
        # each sublist is a bunch of nodes
        G.add_nodes_from(part)
        # it also imlies a number of edges:
        G.add_edges_from(to_edges(part))
    return G

def to_edges(l):
    """ 
        treat `l` as a Graph and returns it's edges 
        to_edges(['a','b','c','d']) -> [(a,b), (b,c),(c,d)]
    """
    it = iter(l)
    last = next(it)

    for current in it:
        yield last, current
        last = current    

G = to_graph(l)
print connected_components(G)
# prints [['a', 'c', 'b', 'e', 'd', 'g', 'f', 'o', 'p'], ['k']]

为了自己有效地解决这个问题,无论如何你都必须将列表转换成图形化的东西,所以你最好从一开始就使用 networkX。

【讨论】:

其实我之后需要这个来创建图表。 @Wistful Jesus:更有理由使用图书馆。 很酷的答案。作为一个让它更短的小建议,to_edges 函数可以替换为izip(part[:-1], part[1:]) connect_components的时间复杂度是多少?【参考方案2】:

算法:

    从列表中取出第一组 A 对于列表中的其他集合 B,如果 B 与 A 有共同元素,则将 B 加入 A;从列表中删除 B 重复 2. 直到不再与 A 重叠 将 A 放入输出 重复 1. 与列表的其余部分

因此您可能希望使用集合而不是列表。下面的程序应该可以做到。

l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]

out = []
while len(l)>0:
    first, *rest = l
    first = set(first)

    lf = -1
    while len(first)>lf:
        lf = len(first)

        rest2 = []
        for r in rest:
            if len(first.intersection(set(r)))>0:
                first |= set(r)
            else:
                rest2.append(r)     
        rest = rest2

    out.append(first)
    l = rest

print(out)

【讨论】:

我喜欢这个答案。对我来说,这个问题感觉就像一个固定的问题。一个小点:优雅的 first, *rest = l 构造仅适用于 Python 3,将其与 first, rest = l[0], l[1:] 交换似乎在 python 2.7 上工作正常【参考方案3】:

我需要对相当大的列表执行数百万次由 OP 描述的聚类技术,因此想确定上面建议的哪种方法既最准确又最高效。

我为上述每种方法对大小从 2^1 到 2^10 的输入列表进行了 10 次试验,对每种方法使用相同的输入列表,并以毫秒为单位测量了上述每种算法的平均运行时间。结果如下:

这些结果帮助我看到了在始终返回正确结果的方法中,@jochen's 是最快的。在那些不能始终返回正确结果的方法中,mak 的解决方案通常不包括所有输入元素(即缺少列表成员列表),并且 braaksma、cmangla 和 asterisk 的解决方案不能保证最大程度地合并.

有趣的是,这两种最快、正确的算法迄今为止获得了前两位的赞成票,按正确的顺序排列。

这是用于运行测试的代码:

from networkx.algorithms.components.connected import connected_components
from itertools import chain
from random import randint, random
from collections import defaultdict, deque
from copy import deepcopy
from multiprocessing import Pool
import networkx
import datetime
import os

##
# @mimomu
##

def mimomu(l):
  l = deepcopy(l)
  s = set(chain.from_iterable(l))
  for i in s:
    components = [x for x in l if i in x]
    for j in components:
      l.remove(j)
    l += [list(set(chain.from_iterable(components)))]
  return l

##
# @Howard
##

def howard(l):
  out = []
  while len(l)>0:
      first, *rest = l
      first = set(first)

      lf = -1
      while len(first)>lf:
          lf = len(first)

          rest2 = []
          for r in rest:
              if len(first.intersection(set(r)))>0:
                  first |= set(r)
              else:
                  rest2.append(r)
          rest = rest2

      out.append(first)
      l = rest
  return out

##
# Nx @Jochen Ritzel
##

def jochen(l):
  l = deepcopy(l)

  def to_graph(l):
      G = networkx.Graph()
      for part in l:
          # each sublist is a bunch of nodes
          G.add_nodes_from(part)
          # it also imlies a number of edges:
          G.add_edges_from(to_edges(part))
      return G

  def to_edges(l):
      """
          treat `l` as a Graph and returns it's edges
          to_edges(['a','b','c','d']) -> [(a,b), (b,c),(c,d)]
      """
      it = iter(l)
      last = next(it)

      for current in it:
          yield last, current
          last = current

  G = to_graph(l)
  return list(connected_components(G))

##
# Merge all @MAK
##

def mak(l):
  l = deepcopy(l)
  taken=[False]*len(l)
  l=map(set,l)

  def dfs(node,index):
      taken[index]=True
      ret=node
      for i,item in enumerate(l):
          if not taken[i] and not ret.isdisjoint(item):
              ret.update(dfs(item,i))
      return ret

  def merge_all():
      ret=[]
      for i,node in enumerate(l):
          if not taken[i]:
              ret.append(list(dfs(node,i)))
      return ret

  result = list(merge_all())
  return result

##
# @cmangla
##

def cmangla(l):
  l = deepcopy(l)
  len_l = len(l)
  i = 0
  while i < (len_l - 1):
    for j in range(i + 1, len_l):
      # i,j iterate over all pairs of l's elements including new
      # elements from merged pairs. We use len_l because len(l)
      # may change as we iterate
      i_set = set(l[i])
      j_set = set(l[j])

      if len(i_set.intersection(j_set)) > 0:
        # Remove these two from list
        l.pop(j)
        l.pop(i)

        # Merge them and append to the orig. list
        ij_union = list(i_set.union(j_set))
        l.append(ij_union)

        # len(l) has changed
        len_l -= 1

        # adjust 'i' because elements shifted
        i -= 1

        # abort inner loop, continue with next l[i]
        break

      i += 1
  return l

##
# @pillmuncher
##

def pillmuncher(l):
  l = deepcopy(l)

  def connected_components(lists):
    neighbors = defaultdict(set)
    seen = set()
    for each in lists:
        for item in each:
            neighbors[item].update(each)
    def component(node, neighbors=neighbors, seen=seen, see=seen.add):
        nodes = set([node])
        next_node = nodes.pop
        while nodes:
            node = next_node()
            see(node)
            nodes |= neighbors[node] - seen
            yield node
    for node in neighbors:
        if node not in seen:
            yield sorted(component(node))

  return list(connected_components(l))

##
# @NicholasBraaksma
##

def braaksma(l):
  l = deepcopy(l)
  lists = sorted([sorted(x) for x in l]) #Sorts lists in place so you dont miss things. Trust me, needs to be done.

  resultslist = [] #Create the empty result list.

  if len(lists) >= 1: # If your list is empty then you dont need to do anything.
      resultlist = [lists[0]] #Add the first item to your resultset
      if len(lists) > 1: #If there is only one list in your list then you dont need to do anything.
          for l in lists[1:]: #Loop through lists starting at list 1
              listset = set(l) #Turn you list into a set
              merged = False #Trigger
              for index in range(len(resultlist)): #Use indexes of the list for speed.
                  rset = set(resultlist[index]) #Get list from you resultset as a set
                  if len(listset & rset) != 0: #If listset and rset have a common value then the len will be greater than 1
                      resultlist[index] = list(listset | rset) #Update the resultlist with the updated union of listset and rset
                      merged = True #Turn trigger to True
                      break #Because you found a match there is no need to continue the for loop.
              if not merged: #If there was no match then add the list to the resultset, so it doesnt get left out.
                  resultlist.append(l)
  return resultlist

##
# @Rumple Stiltskin
##

def stiltskin(l):
  l = deepcopy(l)
  hashdict = defaultdict(int)

  def hashit(x, y):
      for i in y: x[i] += 1
      return x

  def merge(x, y):
      sums = sum([hashdict[i] for i in y])
      if sums > len(y):
          x[0] = x[0].union(y)
      else:
          x[1] = x[1].union(y)
      return x

  hashdict = reduce(hashit, l, hashdict)
  sets = reduce(merge, l, [set(),set()])
  return list(sets)

##
# @Asterisk
##

def asterisk(l):
  l = deepcopy(l)
  results = 
  for sm in ['min', 'max']:
    sort_method = min if sm == 'min' else max
    l = sorted(l, key=lambda x:sort_method(x))
    queue = deque(l)

    grouped = []
    while len(queue) >= 2:
      l1 = queue.popleft()
      l2 = queue.popleft()
      s1 = set(l1)
      s2 = set(l2)

      if s1 & s2:
        queue.appendleft(s1 | s2)
      else:
        grouped.append(s1)
        queue.appendleft(s2)
    if queue:
      grouped.append(queue.pop())
    results[sm] = grouped
  if len(results['min']) < len(results['max']):
    return results['min']
  return results['max']

##
# Validate no more clusters can be merged
##

def validate(output, L):
  # validate all sublists are maximally merged
  d = defaultdict(list)
  for idx, i in enumerate(output):
    for j in i:
      d[j].append(i)
  if any([len(i) > 1 for i in d.values()]):
    return 'not maximally merged'
  # validate all items in L are accounted for
  all_items = set(chain.from_iterable(L))
  accounted_items = set(chain.from_iterable(output))
  if all_items != accounted_items:
    return 'missing items'
  # validate results are good
  return 'true'

##
# Timers
##

def time(func, L):
  start = datetime.datetime.now()
  result = func(L)
  delta = datetime.datetime.now() - start
  return result, delta

##
# Function runner
##

def run_func(args):
  func, L, input_size = args
  results, elapsed = time(func, L)
  validation_result = validate(results, L)
  return func.__name__, input_size, elapsed, validation_result

##
# Main
##

all_results = defaultdict(lambda: defaultdict(list))
funcs = [mimomu, howard, jochen, mak, cmangla, braaksma, asterisk]
args = []

for trial in range(10):
  for s in range(10):
    input_size = 2**s

    # get some random inputs to use for all trials at this size
    L = []
    for i in range(input_size):
      sublist = []
      for j in range(randint(5, 10)):
        sublist.append(randint(0, 2**24))
      L.append(sublist)
    for i in funcs:
      args.append([i, L, input_size])

pool = Pool()
for result in pool.imap(run_func, args):
  func_name, input_size, elapsed, validation_result = result
  all_results[func_name][input_size].append(
    'time': elapsed,
    'validation': validation_result,
  )
  # show the running time for the function at this input size
  print(input_size, func_name, elapsed, validation_result)
pool.close()
pool.join()

# write the average of time trials at each size for each function
with open('times.tsv', 'w') as out:
  for func in all_results:
    validations = [i['validation'] for j in all_results[func] for i in all_results[func][j]]
    linetype = 'incorrect results' if any([i != 'true' for i in validations]) else 'correct results'

    for input_size in all_results[func]:
      all_times = [i['time'].microseconds for i in all_results[func][input_size]]
      avg_time = sum(all_times) / len(all_times)

      out.write(func + '\t' + str(input_size) + '\t' + \
        str(avg_time) + '\t' + linetype + '\n')

对于绘图:

library(ggplot2)
df <- read.table('times.tsv', sep='\t')

p <- ggplot(df, aes(x=V2, y=V3, color=as.factor(V1))) +
  geom_line() +
  xlab('number of input lists') +
  ylab('runtime (ms)') +
  labs(color='') +
  scale_x_continuous(trans='log10') +
  facet_wrap(~V4, ncol=1)

ggsave('runtimes.png')

【讨论】:

【参考方案4】:

我遇到了试图合并具有共同值的列表的相同问题。这个例子可能是你正在寻找的。 它只在列表上循环一次并更新结果集。

lists = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]
lists = sorted([sorted(x) for x in lists]) #Sorts lists in place so you dont miss things. Trust me, needs to be done.

resultslist = [] #Create the empty result list.

if len(lists) >= 1: # If your list is empty then you dont need to do anything.
    resultlist = [lists[0]] #Add the first item to your resultset
    if len(lists) > 1: #If there is only one list in your list then you dont need to do anything.
        for l in lists[1:]: #Loop through lists starting at list 1
            listset = set(l) #Turn you list into a set
            merged = False #Trigger
            for index in range(len(resultlist)): #Use indexes of the list for speed.
                rset = set(resultlist[index]) #Get list from you resultset as a set
                if len(listset & rset) != 0: #If listset and rset have a common value then the len will be greater than 1
                    resultlist[index] = list(listset | rset) #Update the resultlist with the updated union of listset and rset
                    merged = True #Turn trigger to True
                    break #Because you found a match there is no need to continue the for loop.
            if not merged: #If there was no match then add the list to the resultset, so it doesnt get left out.
                resultlist.append(l)
print resultlist

#

resultset = [['a', 'b', 'c', 'd', 'e', 'g', 'f', 'o', 'p'], ['k']]

【讨论】:

这个算法不正确!如果列表类似于 [[0, 2], [1, 8], [1, 4], [2, 8], [2, 6], [3, 5], [6, 9]]那么结果将是 3 个子列表而不是 2 个子列表。 @anirbanBhui 这已被修复 您能否添加另一个条件来检查今天是否是星期三?我只会在星期三合并。【参考方案5】:

我认为这可以通过将问题建模为graph 来解决。每个子列表都是一个节点,并且仅当两个子列表具有某些共同元素时,才与另一个节点共享一条边。因此,合并的子列表基本上是图中的connected component。合并所有这些只是找到所有连接的组件并列出它们的问题。

这可以通过对图的简单遍历来完成。 BFS 和 DFS 都可以使用,但我这里使用的是 DFS,因为它对我来说有点短。

l = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]
taken=[False]*len(l)
l=[set(elem) for elem in l]

def dfs(node,index):
    taken[index]=True
    ret=node
    for i,item in enumerate(l):
        if not taken[i] and not ret.isdisjoint(item):
            ret.update(dfs(item,i))
    return ret

def merge_all():
    ret=[]
    for i,node in enumerate(l):
        if not taken[i]:
            ret.append(list(dfs(node,i)))
    return ret

print(merge_all())

【讨论】:

@duhaime:你能分享一个失败的案例吗? @duhaime:你能分享一个在 Python 3.5 中失败的案例吗? 似乎 OP 预期 L = [['a','b','c','d','e','f','g','o','p'],['k']] 但在 3.5.3 中此代码打印 [['a', 'c', 'b', 'p']]。也许我错过了什么?我上面的帖子使用不同的输入运行随机测试,所以你也可以检查一下...... @duhaime:谢谢!更新了代码以适用于 Python 3.5。【参考方案6】:

作为Jochen Ritzel pointed out,您正在寻找图中的连通分量。以下是在不使用图形库的情况下如何实现它:

from collections import defaultdict

def connected_components(lists):
    neighbors = defaultdict(set)
    seen = set()
    for each in lists:
        for item in each:
            neighbors[item].update(each)
    def component(node, neighbors=neighbors, seen=seen, see=seen.add):
        nodes = set([node])
        next_node = nodes.pop
        while nodes:
            node = next_node()
            see(node)
            nodes |= neighbors[node] - seen
            yield node
    for node in neighbors:
        if node not in seen:
            yield sorted(component(node))

L = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]
print list(connected_components(L))

【讨论】:

【参考方案7】:

你可以使用networkx库,因为这是graph theory和connected components的问题:

import networkx as nx

L = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]

G = nx.Graph()

#Add nodes to Graph    
G.add_nodes_from(sum(L, []))

#Create edges from list of nodes
q = [[(s[i],s[i+1]) for i in range(len(s)-1)] for s in L]

for i in q:

    #Add edges to Graph
    G.add_edges_from(i)

#Find all connnected components in graph and list nodes for each component
[list(i) for i in nx.connected_components(G)]

输出:

[['p', 'c', 'f', 'g', 'o', 'a', 'd', 'b', 'e'], ['k']]

【讨论】:

【参考方案8】:

我想念一个非古怪的版本。我在 2018 年(7 年后)发布它

简单且不稳定的方法:

1)制作笛卡尔积(交叉连接)合并两个共同的 if 元素 2)删除重复

#your list
l=[['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]

#import itertools
from itertools import product, groupby

#inner lists to sets (to list of sets)
l=[set(x) for x in l]

#cartesian product merging elements if some element in common
for a,b in product(l,l):
    if a.intersection( b ):
       a.update(b)
       b.update(a)

#back to list of lists
l = sorted( [sorted(list(x)) for x in l])

#remove dups
list(l for l,_ in groupby(l))

#result
[['a', 'b', 'c', 'd', 'e', 'f', 'g', 'o', 'p'], ['k']]

【讨论】:

【参考方案9】:

我发现 itertools 是合并列表的快速选项,它为我解决了这个问题:

import itertools

LL = set(itertools.chain.from_iterable(L)) 
# LL is 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'k', 'o', 'p'

for each in LL:
  components = [x for x in L if each in x]
  for i in components:
    L.remove(i)
  L += [list(set(itertools.chain.from_iterable(components)))]

# then L = [['k'], ['a', 'c', 'b', 'e', 'd', 'g', 'f', 'o', 'p']]

对于大型集合,按频率从最常见元素到最少元素排序 LL 可以加快速度

【讨论】:

【参考方案10】:

我的尝试。具有实用的外观。

#!/usr/bin/python
from collections import defaultdict
l = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']]
hashdict = defaultdict(int)

def hashit(x, y):
    for i in y: x[i] += 1
    return x

def merge(x, y):
    sums = sum([hashdict[i] for i in y])
    if sums > len(y):
        x[0] = x[0].union(y)
    else:
        x[1] = x[1].union(y)
    return x


hashdict = reduce(hashit, l, hashdict)
sets = reduce(merge, l, [set(),set()])
print [list(sets[0]), list(sets[1])]

【讨论】:

【参考方案11】:

这是一个相当快的解决方案,没有依赖关系。它的工作原理如下:

    为您的每个子目录分配一个唯一的参考编号(在本例中为子列表的初始索引)

    为每个子列表以及每个子列表中的每个项目创建一个引用元素的字典。

    重复以下步骤,直到没有任何变化:

    3a。浏览每个子列表中的每个项目。如果该项目的当前参考编号与其子列表的参考编号不同,则该元素必须是两个列表的一部分。合并两个列表(从引用中删除当前子列表),并将当前子列表中所有项目的引用号设置为新子列表的引用号。

如果此过程没有导致任何更改,那是因为所有元素都恰好是一个列表的一部分。由于工作集的大小在每次迭代中都会减小,因此算法必然会终止。

   def merge_overlapping_sublists(lst):
    output, refs = , 
    for index, sublist in enumerate(lst):
        output[index] = set(sublist)
        for elem in sublist:
            refs[elem] = index
    changes = True
    while changes:
        changes = False
        for ref_num, sublist in list(output.items()):
            for elem in sublist:
                current_ref_num = refs[elem]
                if current_ref_num != ref_num:
                    changes = True
                    output[current_ref_num] |= sublist
                    for elem2 in sublist:
                        refs[elem2] = current_ref_num
                    output.pop(ref_num)
                    break
    return list(output.values())

下面是对这段代码的一组测试:

def compare(a, b):
    a = list(b)
    try:
        for elem in a:
            b.remove(elem)
    except ValueError:
        return False
    return not b

import random
lst = [["a", "b"], ["b", "c"], ["c", "d"], ["d", "e"]]
random.shuffle(lst)
assert compare(merge_overlapping_sublists(lst), ["a", "b", "c", "d", "e"])
lst = [["a", "b"], ["b", "c"], ["f", "d"], ["d", "e"]]
random.shuffle(lst)
assert compare(merge_overlapping_sublists(lst), ["a", "b", "c",, "d", "e", "f"])
lst = [["a", "b"], ["k", "c"], ["f", "g"], ["d", "e"]]
random.shuffle(lst)
assert compare(merge_overlapping_sublists(lst), ["a", "b", "k", "c", "f", "g", "d", "e"])
lst = [["a", "b", "c"], ["b", "d", "e"], ["k"], ["o", "p"], ["e", "f"], ["p", "a"], ["d", "g"]]
random.shuffle(lst)
assert compare(merge_overlapping_sublists(lst), ["k", "a", "c", "b", "e", "d", "g", "f", "o", "p"])    
lst = [["a", "b"], ["b", "c"], ["a"], ["a"], ["b"]]
random.shuffle(lst)
assert compare(merge_overlapping_sublists(lst), ["a", "b", "c"])

注意返回值是一个集合列表。

【讨论】:

【参考方案12】:

在不知道你想要什么的情况下,我决定猜测你的意思:我只想找到每个元素一次。

#!/usr/bin/python


def clink(l, acc):
  for sub in l:
    if sub.__class__ == list:
      clink(sub, acc)
    else:
      acc[sub]=1

def clunk(l):
  acc = 
  clink(l, acc)
  print acc.keys()

l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]

clunk(l)

输出如下:

['a', 'c', 'b', 'e', 'd', 'g', 'f', 'k', 'o', 'p']

【讨论】:

.__class__ == list 看起来非常错误。至少,isinstance(sub, list)。如果只是作为一个原则问题。 (此外,您可以/应该只使用集合而不是具有虚假值的 dict。) @delnan,这两个罪名都有罪:) 根据 OP 的问题,k 也不应该连接到其他组件 @duhaime,嘿,添加该要求的编辑是在我发布答案后添加的。很有启发性的是,我应该先让发帖人写一个更好的问题,而不是回答这个问题。谢谢。【参考方案13】:

这可能是一种更简单/更快的算法,而且似乎运行良好 -

l = [['a', 'b', 'c'], ['b', 'd', 'e'], ['k'], ['o', 'p'], ['e', 'f'], ['p', 'a'], ['d', 'g']]

len_l = len(l)
i = 0
while i < (len_l - 1):
    for j in range(i + 1, len_l):

        # i,j iterate over all pairs of l's elements including new 
        # elements from merged pairs. We use len_l because len(l)
        # may change as we iterate
        i_set = set(l[i])
        j_set = set(l[j])

        if len(i_set.intersection(j_set)) > 0:
            # Remove these two from list
            l.pop(j)
            l.pop(i)

            # Merge them and append to the orig. list
            ij_union = list(i_set.union(j_set))
            l.append(ij_union)

            # len(l) has changed
            len_l -= 1

            # adjust 'i' because elements shifted
            i -= 1

            # abort inner loop, continue with next l[i]
            break

    i += 1

print l
# prints [['k'], ['a', 'c', 'b', 'e', 'd', 'g', 'f', 'o', 'p']]

【讨论】:

【参考方案14】:

简单来说,你可以使用快速查找。

关键是使用两个临时列表。 第一个称为elements,它存储所有组中存在的所有元素。 第二个名为标签。我从 sklearn 的 kmeans 算法中得到了灵感。 'labels' 存储元素的标签或质心。在这里,我简单地将簇中的第一个元素设为质心。最初,这些值从 0 到 length-1,递增。

对于每个组,我在“元素”中得到他们的“索引”。 然后我根据索引获得组标签。 我计算标签的最小值,这将是它们的新标签。 我将 labels for group 中的所有标签替换为新标签。

或者说,对于每次迭代, 我尝试合并两个或多个现有组。 如果组的标签为 0 和 2 我找到了新标签 0,即两者中的最小值。 我将它们替换为 0。

def cluser_combine(groups):
    n_groups=len(groups)

    #first, we put all elements appeared in 'gruops' into 'elements'.
    elements=list(set.union(*[set(g) for g in groups]))
    #and sort elements.
    elements.sort()
    n_elements=len(elements)

    #I create a list called clusters, this is the key of this algorithm.
    #I was inspired by sklearn kmeans implementation.
    #they have an attribute called labels_
    #the url is here:
    #https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html
    #i called this algorithm cluster combine, because of this inspiration.
    labels=list(range(n_elements))


    #for each group, I get their 'indices' in 'elements'
    #I then get the labels for indices.
    #and i calculate the min of the labels, that will be the new label for them.
    #I replace all elements with labels in labels_for_group with the new label.

    #or to say, for each iteration,
    #i try to combine two or more existing groups.
    #if the group has labels of 0 and 2
    #i find out the new label 0, that is the min of the two.
    #i than replace them with 0.
    for i in range(n_groups):

        #if there is only zero/one element in the group, skip
        if len(groups[i])<=1:
            continue

        indices=list(map(elements.index, groups[i]))

        labels_for_group=list(set([labels[i] for i in indices]))
        #if their is only one label, all the elements in group are already have the same label, skip.
        if len(labels_for_group)==1:

            continue

        labels_for_group.sort()
        label=labels_for_group[0]

        #combine
        for k in range(n_elements):
            if labels[k] in labels_for_group[1:]:
                labels[k]=label


    new_groups=[]
    for label in set(labels):
        new_group = [elements[i] for i, v in enumerate(labels) if v == label]
        new_groups.append(new_group)

    return new_groups

我打印出了你的问题的详细结果:

cluser_combine([['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g']])

元素: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'k', 'o', 'p'] 标签: [0、1、2、3、4、5、6、7、8、9] -----------------第 0 组------------- 该组是: ['a', 'b', 'c'] 元素中组的索引 [0, 1, 2] 组合前的标签 [0、1、2、3、4、5、6、7、8、9] 结合... 组合后的标签 [0, 0, 0, 3, 4, 5, 6, 7, 8, 9] -----------------第 1 组------------- 该组是: ['b', 'd', 'e'] 元素中组的索引 [1,3,4] 组合前的标签 [0, 0, 0, 3, 4, 5, 6, 7, 8, 9] 结合... 组合后的标签 [0, 0, 0, 0, 0, 5, 6, 7, 8, 9] --------------------第 2 组---------------------------------------- 该组是: ['k'] --------------------第 3 组---------------------------------------- 该组是: ['o', 'p'] 元素中组的索引 [8, 9] 组合前的标签 [0, 0, 0, 0, 0, 5, 6, 7, 8, 9] 结合... 组合后的标签 [0, 0, 0, 0, 0, 5, 6, 7, 8, 8] --------------------第 4 组---------------------------- 该组是: ['e', 'f'] 元素中组的索引 [4, 5] 组合前的标签 [0, 0, 0, 0, 0, 5, 6, 7, 8, 8] 结合... 组合后的标签 [0, 0, 0, 0, 0, 0, 6, 7, 8, 8] ---------第 5 组------------------------- 该组是: ['p', 'a'] 元素中组的索引 [9, 0] 组合前的标签 [0, 0, 0, 0, 0, 0, 6, 7, 8, 8] 结合... 组合后的标签 [0, 0, 0, 0, 0, 0, 6, 7, 0, 0] -----------------第 6 组------------------------- 该组是: ['d', 'g'] 元素中组的索引 [3, 6] 组合前的标签 [0, 0, 0, 0, 0, 0, 6, 7, 0, 0] 结合... 组合后的标签 [0, 0, 0, 0, 0, 0, 0, 7, 0, 0] ([0, 0, 0, 0, 0, 0, 0, 7, 0, 0], [['a', 'b', 'c', 'd', 'e', 'f', 'g', 'o', 'p'], ['k']])

详情请参考my github jupyter notebook

【讨论】:

【参考方案15】:

这是我的答案。

orig = [['a','b','c'],['b','d','e'],['k'],['o','p'],['e','f'],['p','a'],['d','g'], ['k'],['k'],['k']]

def merge_lists(orig):
    def step(orig): 
        mid = []
        mid.append(orig[0])
        for i in range(len(mid)):            
            for j in range(1,len(orig)):                
                for k in orig[j]:
                    if k in mid[i]:                
                        mid[i].extend(orig[j])                
                        break
                    elif k == orig[j][-1] and orig[j] not in mid:
                        mid.append(orig[j])                        
        mid = [sorted(list(set(x))) for x in mid]
        return mid

    result = step(orig)
    while result != step(result):                    
        result = step(result)                  
    return result

merge_lists(orig)
[['a', 'b', 'c', 'd', 'e', 'f', 'g', 'o', 'p'], ['k']]

【讨论】:

以上是关于合并共享共同元素的列表的主要内容,如果未能解决你的问题,请参考以下文章

基于共同的第一个元素合并二维列表中的元素

合并在R中共享元素的列出的向量

合并至少共享 2 个元素的集合的算法

Leetcode(712)-账户合并

基于共同元素合并 2 个数组

查找两个列表列列表之间的共同元素?