如何测试列表的所有项目都是不相交的?

How to test all items of a list are disjoint?

给定一个包含多个可迭代对象的列表,我想测试是否所有项目都是 disjoint

two sets are said to be disjoint if they have no element in common

示例:

iterables = ["AB", "CDE", "AF"]
all_disjoint(iterables)
# False

iterables = ["AB", "CDE", "FG"]
all_disjoint(iterables)
# True

Python 集合有一个有效的 isdisjoint 方法,但它是为一次测试两个元素而设计的。一种方法是将此方法应用于每个成对的元素组:

import itertools as it


def pairwise_(iterable):
    """s -> (s0,s1), (s1,s2), (s2,s3), ..., (sn,s0)"""
    # Modified: the last element wraps back to the first element.
    a, b = it.tee(iterable, 2)
    first = next(b, None)
    b = it.chain(b, [first])
    return zip(a, b)


def all_disjoint(x):
    return all((set(p0).isdisjoint(set(p1))) for p0, p1 in pairwise_(x))

这里我修改了pairwise itertools recipe最后一次附上第一个元素。然而,这并不完全正确,因为它只测试相邻项目,而不是针对列表中的所有其他项目测试每个项目。我想用更少的代码更优雅地测试所有元素。有更简单的方法吗?

IIUC,你可以获取你的字符串列表,将它们组合起来,然后检查组合长度是否等于该字符串的等效集合长度。

您可以使用 ''.join 连接您的字符串并定义您的函数:

def all_disjoint(iterables):
    total = ''.join(iterables)
    return len(total) == len(set(total))

现在,测试:

all_disjoint(['AB', 'CDE', 'AF'])
# False

all_disjoint(['AB', 'CDE', 'FG'])
# True

首先,set(list('AB')) 会导致集合 {'A', 'B'}

其次,通过枚举 s 然后使用 for s2 in s[n+1:] 只查看上对角线,避免了将值与自身或另一对进行比较的需要。例如,如果 s = ['A', 'B', 'C'],则 [(s1, s2) for n, s1 in enumerate(s) for s2 in s[n+1:]] 将导致:[('A', 'B'), ('A', 'C'), ('B', 'C')]。如果要从 itertools.

导入 combinations,这相当于 list(combinations(s, 2)) 的结果

鉴于以上情况,我使用any生成器来比较每个子集之间是否缺少任何交集。

由于 any 构造,它会在第一次观察到共同元素时短路,避免计算每一对。

s = ['AB', 'CDE', 'AF']
>>> not any(set(list(s1)).intersection(set(list(s2))) 
            for n, s1 in enumerate(s) for s2 in s[n+1:])
False

s = ['AB', 'CDE', 'FG']
>>> not any(set(list(s1)).intersection(set(list(s2))) 
            for n, s1 in enumerate(s) for s2 in s[n+1:])
True

我为其他感兴趣的人添加这些答案。

方法 1:我意识到这可以通过多重集 (Counter) 来完成。

import itertools as it
import collections as ct


def all_disjoint(iterables):
    return all(not v-1 for v in ct.Counter(it.chain.from_iterable(iterables)).values())

方法 2:从 more_itertools library, more_itertools.unique_to_each 产生每个可迭代对象的所有唯一项。以下代码将结果的长度与原始可迭代对象的长度进行比较:

import more_itertools as mit

def all_disjoint(iterables):
    return all(len(x) == len(y) for x, y in zip(iterables, mit.unique_to_each(*iterables)))

鉴于您所说的要测试每个项目是否与 所有 其他项目不相交,我认为这符合您的要求:

import itertools as it

def all_disjoint(x):
    return all((set(p0).isdisjoint(set(p1))) for p0, p1 in it.combinations(x, 2))

iterables = ['AB', 'CDE', 'AF']
print(all_disjoint(iterables))  # -> False

iterables = ['AB', 'CDE', 'FG']
print(all_disjoint(iterables))  # -> True

# your code gives different answer on this one 
# (because it doesn't check what you want)
iterables = ['AB', 'CDE', 'AH', 'FG']
print(all_disjoint(iterables))  # -> False