将 numpy 数组添加到堆队列

Adding numpy array to a heap queue

有人可以解释为什么以下代码会导致 ValueError 吗?

import heapq
import numpy as np

a = np.ones((2, 2), dtype=int)

states = []
heapq.heappush(states, (0, a))
heapq.heappush(states, (0, a.copy()))

错误信息是:

Traceback (most recent call last):
  File "x.py", line 8, in <module>
    heapq.heappush(states, (0, a.copy()))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

运行 它没有将 a.copy() 添加到堆中工作正常,second/subsequent 一个由于某种原因是一个问题。我确实知道 [True, False, True] 数组有一个未知的真值方面,并且不可能从中确定单个 TrueFalse,但为什么 [=16= 】 需要那样做吗?特别是只有第二种情况?

TL;DR: 因为 numpy 数组如果包含多个元素则无法转换为布尔值。


关于堆的一些信息:

堆 "order" 它们的内容(因此项目必须实现 < 但这是一个实现细节)。

但是,您可以通过为项目创建 tuple 来将项目插入 heap,其中第一个元素是某个值,第二个元素是数组。

比较元组首先检查第一项是否相等,如果相等则检查第二项是否相等,依此类推直到不相等,然后检查它是否更小(当操作为 <) 或更大(对于 >)。然而,元组是在 C 中实现的,== 检查与 Python 中的检查有点不同。它使用 PyObject_RichCompareBool。特别是 "note" 在这里很重要

If o1 and o2 are the same object, PyObject_RichCompareBool() will always return 1 for Py_EQand 0 for Py_NE.

现在让我们转到 numpy 数组:

如果 numpy.array 包含多个项目,则无法将其转换为 bool

>>> arr = np.array([1,2,3])
>>> bool(arr)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

if 检查将条件隐式转换为布尔值:

>>> if arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

即使比较 numpy-arrays 它们仍然是 numpy 数组:

>>> arr > arr
array([False, False], dtype=bool)
>>> arr == arr
array([ True,  True], dtype=bool)

所以这些不能用==计算:

>>> if arr == arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

所以您不能将具有多个元素的 numpy-arrays 转换为布尔值!然而现在有趣的部分来了:heapq-模块使用 PyObject_RichCompareBool() 所以它可以检查两个数组是否相等但是 当且仅当 它们是相同的!

这就是为什么它可以处理多次传入的同一个数组,但在复制时却失败了:

>>> arr is arr
True
>>> arr is arr.copy()
False

当前答案没有讨论解决方案。 一种解决方法是将 np.array 转换为 tuple,然后再将其输入 heapq 处理的 states 列表。 tuple 函数只适用于数组的第一维,所以一个快速的方法是先展平数组:

states = []
heapq.heappush(states, (0, tuple(a.flat)))
heapq.heappush(states, (0, tuple(a.flat)))

或者,可以通过使用 (来自这个answer):

def totuple(a):
    try:
        return tuple(totuple(i) for i in a)
    except TypeError:
        return a