numpy.where 是如何运作的?
how does numpy.where work?
我可以理解以下 numpy 行为。
>>> a
array([[ 0. , 0. , 0. ],
[ 0. , 0.7, 0. ],
[ 0. , 0.3, 0.5],
[ 0.6, 0. , 0.8],
[ 0.7, 0. , 0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. , 0.7, 0.5, 0.8, 0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7, 0.7, 0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))
我理解 0.7、0.7 和 0.8 是 a[1,1]、a[3,2] 和 a[4,0] 所以我得到元组 (array[1,3,4] and array[1,2,0])
每个数组由 0 组成和这三个元素的第一个索引。然后我尝试了其他例子,看看我的理解是否正确。
>>> np.where(a == [0.3])
(array([2]), array([1]))
0.3 在 [2,1] 中,所以结果符合我的预期。然后我试了
>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)
??我希望看到 (array([2,2]),array([2,3]))。为什么我会看到上面的输出?
>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))
第二个结果我也看不懂。有人可以向我解释一下吗?谢谢。
首先要意识到 np.where(a == [whatever])
只是向您显示 a == [whatever]
为 True 的索引。因此,您可以通过查看 a == [whatever]
的值来获得提示。在你的情况下 "works":
>>> a == [0.7, 0.7, 0.8]
array([[False, False, False],
[False, True, False],
[False, False, False],
[False, False, True],
[ True, False, False]], dtype=bool)
你并没有得到你认为的那样。您认为这是分别请求每个元素的索引,但实际上它是在行 的相同位置获取值匹配 的位置。基本上这个比较所做的就是说 "for each row, tell me whether the first element is 0.7, whether the second is 0.7, and whether the third is 0.8"。然后 returns 那些匹配位置的索引。换句话说,比较是在整行之间进行的,而不仅仅是单个值。对于你的最后一个例子:
>>> a == [0.8,0.7,0.7]
array([[False, False, False],
[False, True, False],
[False, False, False],
[False, False, False],
[False, False, False]], dtype=bool)
你现在得到了不同的结果。它不要求 "the indices where a
has value 0.8",它只要求在行的开头有 0.8 的索引 -- 同样在后面两个位置中的任何一个有 0.7 .
这种 row-wise 比较只有在您比较的值与 a
的单行形状匹配时才能进行。因此,当您使用 two-element 列表尝试它时,它 returns 是一个空集,因为它试图将列表作为标量值与数组中的各个值进行比较。
结果是您不能在值列表上使用 ==
并期望它只告诉您任何值出现的位置。相等性将按值 和位置 匹配(如果您比较的值与数组的一行形状相同),或者它将尝试将整个列表作为标量进行比较(如果形状不匹配)。如果您想独立搜索值,则需要执行 Khris 在评论中建议的操作:
np.where((a==0.3)|(a==0.5))
也就是说,您需要对单独的值进行两次(或更多次)单独比较,而不是对值列表进行一次比较。
我可以理解以下 numpy 行为。
>>> a
array([[ 0. , 0. , 0. ],
[ 0. , 0.7, 0. ],
[ 0. , 0.3, 0.5],
[ 0.6, 0. , 0.8],
[ 0.7, 0. , 0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. , 0.7, 0.5, 0.8, 0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7, 0.7, 0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))
我理解 0.7、0.7 和 0.8 是 a[1,1]、a[3,2] 和 a[4,0] 所以我得到元组 (array[1,3,4] and array[1,2,0])
每个数组由 0 组成和这三个元素的第一个索引。然后我尝试了其他例子,看看我的理解是否正确。
>>> np.where(a == [0.3])
(array([2]), array([1]))
0.3 在 [2,1] 中,所以结果符合我的预期。然后我试了
>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)
??我希望看到 (array([2,2]),array([2,3]))。为什么我会看到上面的输出?
>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))
第二个结果我也看不懂。有人可以向我解释一下吗?谢谢。
首先要意识到 np.where(a == [whatever])
只是向您显示 a == [whatever]
为 True 的索引。因此,您可以通过查看 a == [whatever]
的值来获得提示。在你的情况下 "works":
>>> a == [0.7, 0.7, 0.8]
array([[False, False, False],
[False, True, False],
[False, False, False],
[False, False, True],
[ True, False, False]], dtype=bool)
你并没有得到你认为的那样。您认为这是分别请求每个元素的索引,但实际上它是在行 的相同位置获取值匹配 的位置。基本上这个比较所做的就是说 "for each row, tell me whether the first element is 0.7, whether the second is 0.7, and whether the third is 0.8"。然后 returns 那些匹配位置的索引。换句话说,比较是在整行之间进行的,而不仅仅是单个值。对于你的最后一个例子:
>>> a == [0.8,0.7,0.7]
array([[False, False, False],
[False, True, False],
[False, False, False],
[False, False, False],
[False, False, False]], dtype=bool)
你现在得到了不同的结果。它不要求 "the indices where a
has value 0.8",它只要求在行的开头有 0.8 的索引 -- 同样在后面两个位置中的任何一个有 0.7 .
这种 row-wise 比较只有在您比较的值与 a
的单行形状匹配时才能进行。因此,当您使用 two-element 列表尝试它时,它 returns 是一个空集,因为它试图将列表作为标量值与数组中的各个值进行比较。
结果是您不能在值列表上使用 ==
并期望它只告诉您任何值出现的位置。相等性将按值 和位置 匹配(如果您比较的值与数组的一行形状相同),或者它将尝试将整个列表作为标量进行比较(如果形状不匹配)。如果您想独立搜索值,则需要执行 Khris 在评论中建议的操作:
np.where((a==0.3)|(a==0.5))
也就是说,您需要对单独的值进行两次(或更多次)单独比较,而不是对值列表进行一次比较。