将 numpy 数组添加到堆队列
Adding numpy array to a heap queue
有人可以解释为什么以下代码会导致 ValueError 吗?
import heapq
import numpy as np
a = np.ones((2, 2), dtype=int)
states = []
heapq.heappush(states, (0, a))
heapq.heappush(states, (0, a.copy()))
错误信息是:
Traceback (most recent call last):
File "x.py", line 8, in <module>
heapq.heappush(states, (0, a.copy()))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
运行 它没有将 a.copy()
添加到堆中工作正常,second/subsequent 一个由于某种原因是一个问题。我确实知道 [True, False, True]
数组有一个未知的真值方面,并且不可能从中确定单个 True
或 False
,但为什么 [=16= 】 需要那样做吗?特别是只有第二种情况?
TL;DR: 因为 numpy 数组如果包含多个元素则无法转换为布尔值。
关于堆的一些信息:
堆 "order" 它们的内容(因此项目必须实现 <
但这是一个实现细节)。
但是,您可以通过为项目创建 tuple
来将项目插入 heap
,其中第一个元素是某个值,第二个元素是数组。
比较元组首先检查第一项是否相等,如果相等则检查第二项是否相等,依此类推直到不相等,然后检查它是否更小(当操作为 <
) 或更大(对于 >
)。然而,元组是在 C 中实现的,==
检查与 Python 中的检查有点不同。它使用 PyObject_RichCompareBool
。特别是 "note" 在这里很重要
If o1
and o2
are the same object, PyObject_RichCompareBool()
will always return 1 for Py_EQ
and 0 for Py_NE
.
现在让我们转到 numpy 数组:
如果 numpy.array
包含多个项目,则无法将其转换为 bool
:
>>> arr = np.array([1,2,3])
>>> bool(arr)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
if
检查将条件隐式转换为布尔值:
>>> if arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
即使比较 numpy-arrays 它们仍然是 numpy 数组:
>>> arr > arr
array([False, False], dtype=bool)
>>> arr == arr
array([ True, True], dtype=bool)
所以这些不能用==
计算:
>>> if arr == arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
所以您不能将具有多个元素的 numpy-arrays 转换为布尔值!然而现在有趣的部分来了:heapq
-模块使用 PyObject_RichCompareBool()
所以它可以检查两个数组是否相等但是 当且仅当 它们是相同的!
这就是为什么它可以处理多次传入的同一个数组,但在复制时却失败了:
>>> arr is arr
True
>>> arr is arr.copy()
False
当前答案没有讨论解决方案。
一种解决方法是将 np.array
转换为 tuple
,然后再将其输入 heapq
处理的 states
列表。
tuple
函数只适用于数组的第一维,所以一个快速的方法是先展平数组:
states = []
heapq.heappush(states, (0, tuple(a.flat)))
heapq.heappush(states, (0, tuple(a.flat)))
或者,可以通过使用
(来自这个answer):
def totuple(a):
try:
return tuple(totuple(i) for i in a)
except TypeError:
return a
有人可以解释为什么以下代码会导致 ValueError 吗?
import heapq
import numpy as np
a = np.ones((2, 2), dtype=int)
states = []
heapq.heappush(states, (0, a))
heapq.heappush(states, (0, a.copy()))
错误信息是:
Traceback (most recent call last):
File "x.py", line 8, in <module>
heapq.heappush(states, (0, a.copy()))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
运行 它没有将 a.copy()
添加到堆中工作正常,second/subsequent 一个由于某种原因是一个问题。我确实知道 [True, False, True]
数组有一个未知的真值方面,并且不可能从中确定单个 True
或 False
,但为什么 [=16= 】 需要那样做吗?特别是只有第二种情况?
TL;DR: 因为 numpy 数组如果包含多个元素则无法转换为布尔值。
关于堆的一些信息:
堆 "order" 它们的内容(因此项目必须实现 <
但这是一个实现细节)。
但是,您可以通过为项目创建 tuple
来将项目插入 heap
,其中第一个元素是某个值,第二个元素是数组。
比较元组首先检查第一项是否相等,如果相等则检查第二项是否相等,依此类推直到不相等,然后检查它是否更小(当操作为 <
) 或更大(对于 >
)。然而,元组是在 C 中实现的,==
检查与 Python 中的检查有点不同。它使用 PyObject_RichCompareBool
。特别是 "note" 在这里很重要
If
o1
ando2
are the same object,PyObject_RichCompareBool()
will always return 1 forPy_EQ
and 0 forPy_NE
.
现在让我们转到 numpy 数组:
如果 numpy.array
包含多个项目,则无法将其转换为 bool
:
>>> arr = np.array([1,2,3])
>>> bool(arr)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
if
检查将条件隐式转换为布尔值:
>>> if arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
即使比较 numpy-arrays 它们仍然是 numpy 数组:
>>> arr > arr
array([False, False], dtype=bool)
>>> arr == arr
array([ True, True], dtype=bool)
所以这些不能用==
计算:
>>> if arr == arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
所以您不能将具有多个元素的 numpy-arrays 转换为布尔值!然而现在有趣的部分来了:heapq
-模块使用 PyObject_RichCompareBool()
所以它可以检查两个数组是否相等但是 当且仅当 它们是相同的!
这就是为什么它可以处理多次传入的同一个数组,但在复制时却失败了:
>>> arr is arr
True
>>> arr is arr.copy()
False
当前答案没有讨论解决方案。
一种解决方法是将 np.array
转换为 tuple
,然后再将其输入 heapq
处理的 states
列表。
tuple
函数只适用于数组的第一维,所以一个快速的方法是先展平数组:
states = []
heapq.heappush(states, (0, tuple(a.flat)))
heapq.heappush(states, (0, tuple(a.flat)))
或者,可以通过使用 (来自这个answer):
def totuple(a):
try:
return tuple(totuple(i) for i in a)
except TypeError:
return a