选择包含所有给定值的 numpy 数组行

Selecting rows of numpy array that contains all the given values

我有一个 np.array:

matrix = np.array([['A', 'B', 'C'], ['A', 'B', np.nan], ['C', np.nan, np.nan] ])

我想高效地select所有包含给定值的行

samples = ['C', 'A']

但是当我做的时候:

mask = np.isin(matrix, samples)

我明白了

array([[ True, False,  True],
       [ True, False, False],
       [ True, False, False]])

当仅在包含两个值的行中为 True 时,如何有效地获取掩码?

我专注于高效,因为它是一个稀疏的大矩阵。

谢谢大家提前预估。

我的第一个方法是

[np.isin(samples, row).all() for row in matrix]
# [True, False, False]

(但老实说,效率和性能什么的都说不出来...)

如果你想要矢量化的东西,我建议通过将其转换为 3D 并在三维上广播来进行比较。然后对于每个切片,检查每一行以查看是否有任何内容 True。最后,如果我们看到对于每一行,每个元素都是 True,那么这就是我们应该 return.

的结果
In [40]: matrix = np.array([['A', 'B', 'C'], ['A', 'B', np.nan], ['C', np.nan, np.nan] ])

In [41]: samples = ['C', 'A']

In [42]: samples = np.array(samples)

In [43]: mask = matrix[...,None] == samples[None,None]

In [44]: mask
Out[44]:
array([[[False,  True],
        [False, False],
        [ True, False]],

       [[False,  True],
        [False, False],
        [False, False]],

       [[ True, False],
        [False, False],
        [False, False]]])

In [45]: mask = np.any(mask, axis=1)

In [46]: mask
Out[46]:
array([[ True,  True],
       [False,  True],
       [ True, False]])

In [47]: mask = np.all(mask, axis=1)

In [48]: mask
Out[48]: array([ True, False, False])

稍后执行此操作:

# Define data
matrix = np.array([['A', 'B', 'C'], ['A', 'B', np.nan], ['C', np.nan, np.nan] ])
samples = ['C', 'A']

# Solution
mask = np.all(np.any(matrix[...,None] == np.array(samples)[None,None], axis=1), axis=1)

请注意,这可能不适用于大型稀疏矩阵....

这是可能对您有所帮助的伪代码:

idxRows = []
for idx, i in enumerate(mask):
    if True in i:
        idxRows.append(idx)

这将为您提供包含所述样本的所有行的索引。

我终于用上了:

#Filter
test_elements = ['A', 'B']
mask = np.isin(matrix, test_elements)
vec_mask = np.isin(mask.sum(axis=1), [len(test_elements)])
ids = np.where(vec_mask)
existence = matrix[ids]

谢谢大家的帮助。