我们如何找到布尔数组中沿轴 1 存在“真”值的索引?

How can we find indices where `True` value exists along axis 1 in a boolean array?

这是一个布尔数组

In [102]: arr        
Out[102]: 
array([[0, 1, 0, 0],
       [1, 0, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 1],
       [1, 0, 0, 0],
       [1, 0, 0, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 1]], dtype=uint8)

我想计算沿轴 1 的索引,其中有一个 1,然后停止并继续下一列,直到我用完所有列。因此,预期的解决方案是:

array([ 1,  4,  5,  6,  0,  7,  8,  2,  9, 10,  3, 11, 12])

为了进一步解释上述结果是如何产生的:我们从第 1 列开始,沿着轴 1 下降,我们在索引 1 处遇到 1,然后在索引 4 等等,直到我们在该列中的位置 6 遇到最后一个 1。所以,我们就此停止,跳过该列的其余部分,因为不会再有 1s,然后继续到第二列,我们在索引 0 处遇到 1 等等,直到我们耗尽所有列。通过将我们到目前为止收集的所有索引放在一起应该给我们结果数组。

我有一个循环的解决方案,但更喜欢矢量化的解决方案。我们如何解决这个问题?任何初步想法都会有很大帮助。

你应该试试 python 中的 numpy 库 它非常有效且易于使用。

v = [[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[1, 0, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 0, 0, 1]];
import numpy as np
flattenedArray = np.array(v).ravel()

print flattenedArray
for i in range(len(flattenedArray)):
if flattenedArray[i]>0:
       print i
In [134]: arr=np.array([[0, 1, 0, 0], 
     ...:        [1, 0, 0, 0], 
     ...:        [0, 0, 1, 0], 
     ...:        [0, 0, 0, 1], 
     ...:        [1, 0, 0, 0], 
     ...:        [1, 0, 0, 0], 
     ...:        [1, 0, 0, 0], 
     ...:        [0, 1, 0, 0], 
     ...:        [0, 1, 0, 0], 
     ...:        [0, 0, 1, 0], 
     ...:        [0, 0, 1, 0], 
     ...:        [0, 0, 0, 1], 
     ...:        [0, 0, 0, 1]], dtype=np.uint8)   

看起来 where 在转置上找到了所需的索引:

In [135]: np.where(arr.T)                                                    
Out[135]: 
(array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
 array([ 1,  4,  5,  6,  0,  7,  8,  2,  9, 10,  3, 11, 12]))
In [136]: np.where(arr.T)[1]                                                 
Out[136]: array([ 1,  4,  5,  6,  0,  7,  8,  2,  9, 10,  3, 11, 12])

如上所说。如果你想要一个向量形式的解决方案,那么 numpy 是最好的选择。

import numpy as np
arr = arr.transpose()
y = np.arange(arr.shape[0])
result =  list(map(lambda j: y[j==1], x))   
result = np.concatenate(result)

诀窍是使用单独的数组y根据条件过滤索引。

您可以将 matrix 作为您的输入:

result = [i for i,x in enumerate(matrix.transpose().flatten()) if x == 1]