对嵌套列表使用 np.where

Using np.where for nested lists

我正在尝试对嵌套列表使用 np.where() 函数。 我想在嵌套列表的第一层找到一个给定条件的索引。

比如我有下面的代码

arr = [[1,1], [2,2],[3,3]]
a = np.where(arr == [2,2])

那么理想情况下我想将 return 'a' 编码为 1。 由于 [2,2] 在嵌套列表的索引 1 中。

但是,结果我只得到一个空数组。

当然,我可以通过实现外部for循环使其更容易工作,例如

for n in range(len(arr)):
   if arr[n] == [2,2]:
      a = n

但我想在函数 np.where 中简单地实现它(在这里写下整个代码)。

有办法吗?

最好的解决方案是@Michael Szczesny 提到的,但是使用 np.where 你也可以这样做:

a = np.where(np.array(arr) == [2, 2])[0]
resulted_ind = np.where(np.bincount(a) == 2)[0]  # --> [1]

好吧,您可以编写自己的函数来执行此操作:

你需要

  • 找到与您查找的内容相同的每一行
  • 获取找到的行的索引(您可以使用 where):

numpy压缩

您可以使用压缩运算符查看每一行是否满足条件。如:

np_arr = np.array(
    [1, 2, 3, 4, 5]
)
print(np_arr < 3)

这将 return 一个布尔值,其中每个元素都是 TrueFalse 满足条件的:

[ True  True False False False]

对于二维数组,您将得到一个二维布尔数组:

to_find = np.array([2, 2])
np_arr = np.array(
    [
        [1, 1],
        [2, 2],
        [3, 3],
        [2, 2]
    ]
)
print(np_arr == to_find)

结果是:

[[False False]
 [ True  True]
 [False False]
 [ True  True]]

现在我们正在寻找具有 all True 值的行。所以我们可以使用ndarrayall方法。我们将向所有人提供我们期待的轴心。 X、Y 或两者。我们要查看 x 轴:

to_find = np.array([2, 2])
np_arr = np.array(
    [
        [1, 1],
        [2, 2],
        [3, 3],
        [2, 2]
    ]
)
print((np_arr == to_find).all(axis=1))

结果是:

[False  True False  True]

获取 Trues

的索引

最后您要查找值为 True:

的索引
np.where((np_arr == to_find).all(axis=1))

结果将是:

(array([1, 3]),)

numpy 在 Python 中运行,因此您可以同时使用基本的 Python 列表和 numpy 数组(更像是 MATLAB 矩阵)

列表列表:

In [43]: alist = [[1,1], [2,2],[3,3]]

一个列表有一个index方法,它针对列表的每个元素进行测试(这里的元素是2个元素列表):

In [44]: alist.index([2,2])
Out[44]: 1
In [45]: alist.index([2,3])
Traceback (most recent call last):
  Input In [45] in <cell line: 1>
    alist.index([2,3])
ValueError: [2, 3] is not in list

alist==[2,2]returnsFalse,因为列表和[2,2]列表不一样

如果我们从该列表创建一个数组:

In [46]: arr = np.array(alist)
In [47]: arr
Out[47]: 
array([[1, 1],
       [2, 2],
       [3, 3]])

我们可以进行 == 测试 - 但它会比较数字元素。

In [48]: arr == np.array([2,2])
Out[48]: 
array([[False, False],
       [ True,  True],
       [False, False]])

此比较的基础是 broadcasting 的概念,允许它比较 (3,2) 数组与 (2,)(2d 与 1d)。这是微不足道的,但它可能要复杂得多。

要查找所有值为 True 的行,请使用:

In [50]: (arr == np.array([2,2])).all(axis=1)
Out[50]: array([False,  True, False])

where 在该数组中找到 True(结果是一个包含 1 个数组的元组):

In [51]: np.where(_)
Out[51]: (array([1]),)

在 Octave 中相当于:

>> arr = [[1,1];[2,2];[3,3]]
arr =

   1   1
   2   2
   3   3

>> all(arr == [2,2],2)
ans =

  0
  1
  0
>> find(all(arr == [2,2],2))
ans =  2