如何避免在 heapq 中使用 _siftup 或 _siftdown
how to avoid using _siftup or _siftdown in heapq
我不知道如何在不使用 _siftup
或 _siftdown
的情况下有效地解决以下问题:
当一个元素乱序时,如何恢复堆不变性?
换句话说,将 heap
中的 old_value
更新为 new_value
,并保持 heap
正常工作。您可以假设堆中只有一个 old_value
。函数定义如下:
def update_value_in_heap(heap, old_value, new_value):
这是我的真实场景,有兴趣的可以看看
你可以想象它是一个小型的自动完成系统。我需要数数
词的频率,并保持前k个max-count词,这
随时准备输出。所以我在这里使用heap
。当一个字
count++,如果它在堆中,我需要更新它。
所有单词和计数都存储在 trie-tree 的叶子和堆中
存储在 trie-tree 的中间节点中。如果你关心这个词
out of heap,不用担心,我可以从trie-tree的叶子节点中获取它。
当用户输入一个词时,它会先从堆中读取然后更新
它。为了更好的性能,我们可以考虑降低更新频率
by 批量更新
那么当一个特定的字数增加时如何更新堆?
这是 _siftup 或 _siftdown 版本的简单示例(不是我的场景):
>>> from heapq import _siftup, _siftdown, heapify, heappop
>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 22 # increase the 8 to 22
>>> i = data.index(old)
>>> data[i] = new
>>> _siftup(data, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 5, 7, 10, 18, 19, 22, 37]
>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 4 # decrease the 8 to 4
>>> i = data.index(old)
>>> data[i] = new
>>> _siftdown(data, 0, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 4, 5, 7, 10, 18, 19, 37]
索引成本为 O(n),更新成本为 O(logn)。 heapify
是另一种解决方案,但是
效率低于 _siftup
或 _siftdown
.
但是_siftup
和_siftdown
是heapq中受保护的成员,所以不建议从外部访问。
那么有没有更好更高效的方法来解决这个问题呢?这种情况的最佳做法?
感谢阅读,非常感谢它能帮助我。 :)
已参考 ,但没有解决我的问题
TL;DR 使用 heapify
.
您必须牢记的一件重要事情是,理论复杂性和性能是两个不同的事物(即使它们是相关的)。换句话说,实施也很重要。渐近复杂性给你一些 下限 你可以将其视为保证,例如 O(n) 中的算法确保在最坏的情况下,你将执行许多指令与输入大小成线性关系。这里有两件重要的事情:
- 常数被忽略,但常数在现实生活中很重要;
- 最坏的情况取决于您考虑的算法,而不仅仅是输入。
根据您考虑的topic/problem,第一点可能非常重要。在某些领域,隐藏在渐近复杂性中的常量是如此之大,以至于您甚至无法构建比常量更大的输入(或者考虑该输入是不现实的)。这不是这里的情况,但这是你必须始终牢记的事情。
鉴于这两个观察结果,您不能真的说:实现 B 比 A 快,因为 A 是从 O(n) 算法派生的,而 B 是从 O(log n) 算法派生的算法。即使这是一个很好的开端一般来说,它并不总是足够的。当所有输入都同样可能发生时,理论复杂性对于比较算法特别有用。换句话说,当你的算法非常通用时。
如果您知道您的用例和输入是什么,您可以直接测试性能。使用测试和渐近复杂度将使您对算法的执行方式有一个很好的了解(在极端情况和任意实际情况下)。
也就是说,让 运行 对以下 class 进行一些性能测试,这些 class 将实现 (there are actually four strategies here, but Invalidate and Reinsert doesn't seem right in your case as you'll invalidate each item as many time as you see a given word). I'll include most of my code so you can double check that I haven't messed up (you can even check the complete notebook):
from heapq import _siftup, _siftdown, heapify, heappop
class Heap(list):
def __init__(self, values, sort=False, heap=False):
super().__init__(values)
heapify(self)
self._broken = False
self.sort = sort
self.heap = heap or not sort
# Solution 1) repair using the knowledge we have after every update:
def update(self, key, value):
old, self[key] = self[key], value
if value > old:
_siftup(self, key)
else:
_siftdown(self, 0, key)
# Solution 2 and 3) repair using sort/heapify in a lazzy way:
def __setitem__(self, key, value):
super().__setitem__(key, value)
self._broken = True
def __getitem__(self, key):
if self._broken:
self._repair()
self._broken = False
return super().__getitem__(key)
def _repair(self):
if self.sort:
self.sort()
elif self.heap:
heapify(self)
# … you'll also need to delegate all other heap functions, for example:
def pop(self):
self._repair()
return heappop(self)
我们可以先检查三种方法是否都有效:
data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
heap = Heap(data[:])
heap.update(8, 22)
heap.update(7, 4)
print(heap)
heap = Heap(data[:], sort_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)
heap = Heap(data[:], heap_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)
然后我们可以运行使用以下函数进行一些性能测试:
import time
import random
def rand_update(heap, lazzy_fix=False, **kwargs):
index = random.randint(0, len(heap)-1)
new_value = random.randint(max_int+1, max_int*2)
if lazzy_fix:
heap[index] = new_value
else:
heap.update(index, new_value)
def rand_updates(n, heap, lazzy_fix=False, **kwargs):
for _ in range(n):
rand_update(heap, lazzy_fix)
def run_perf_test(n, data, **kwargs):
test_heap = Heap(data[:], **kwargs)
t0 = time.time()
rand_updates(n, test_heap, **kwargs)
test_heap[0]
return (time.time() - t0)*1e3
results = []
max_int = 500
nb_updates = 1
for i in range(3, 7):
test_size = 10**i
test_data = [random.randint(0, max_int) for _ in range(test_size)]
perf = run_perf_test(nb_updates, test_data)
results.append((test_size, "update", perf))
perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, heap_fix=True)
results.append((test_size, "heapify", perf))
perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, sort_fix=True)
results.append((test_size, "sort", perf))
结果如下:
import pandas as pd
import seaborn as sns
dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)
sns.lineplot(
data=dtf,
x="heap size",
y="duration (ms)",
hue="method",
)
从这些测试中我们可以看出 heapify
似乎是最合理的选择,它在最坏的情况下具有不错的复杂度:O(n),并且在实践中表现更好。另一方面,研究其他选项可能是个好主意(比如拥有一个专用于该特定问题的数据结构,例如使用箱子将单词放入其中,然后将它们从一个箱子移动到下一个看起来像是一个可能的轨道调查)。
重要说明:这种情况(更新与读取比率 1:1)对 heapify
和 sort
解决方案都是不利的。所以如果你设法有一个k:1的比例,这个结论就更清楚了(你可以把上面代码中的nb_updates = 1
换成nb_updates = k
)。
数据框详细信息:
heap size method duration in ms
0 1000 update 0.435114
1 1000 heapify 0.073195
2 1000 sort 0.101089
3 10000 update 1.668930
4 10000 heapify 0.480175
5 10000 sort 1.151085
6 100000 update 13.194084
7 100000 heapify 4.875898
8 100000 sort 11.922121
9 1000000 update 153.587103
10 1000000 heapify 51.237106
11 1000000 sort 145.306110
@cglacet 的回答是完全错误的,但看起来很合法。
他提供的代码片段完全坏了!它也很难阅读。
_siftup()
在 heapify()
中被调用 n//2 次,因此它本身不能比 _siftup()
快。
回答原问题,没有更好的办法。如果您担心这些方法是私有的,请创建您自己的方法来做同样的事情。
我唯一同意的是,如果长时间不需要从堆中读取,可能有利于懒惰heapify()
一旦你需要它们。问题是你是否应该为此使用堆。
让我们用他的代码片段来解决问题:
heapify()
函数被多次调用以进行“更新”运行。
导致这种情况的错误链如下:
- 他通过了
heap_fix
,但预计heap
sort
也是如此
- 如果
self.sort
总是 False
,那么 self.heap
总是 True
- 他重新定义了
__getitem__()
和__setitem__()
,每次_siftdown()
的_siftup()
赋值或读取某些东西时都会调用它们(注意:这两个在C中没有调用,所以他们使用 __getitem__()
和 __setitem__()
)
- 如果
self.heap
是True
并且正在调用__getitem__()
和__setitem__()
,则每次_siftup()
或调用_repair()
函数siftdown()
交换元素。但是对 heapify()
的调用是在 C 中完成的,所以 __getitem__()
不会被调用,也不会陷入无限循环
- 他重新定义了
self.sort
所以调用它,就像他尝试做的那样,会失败
- 他读过一次,但更新了一个项目
nb_updates
次,而不是他声称的 1:1
我修正了这个例子,我试着尽可能地验证它,但我们都会犯错误。有空自己查一下。
代码
import time
import random
from heapq import _siftup, _siftdown, heapify, heappop
class UpdateHeap(list):
def __init__(self, values):
super().__init__(values)
heapify(self)
def update(self, index, value):
old, self[index] = self[index], value
if value > old:
_siftup(self, index)
else:
_siftdown(self, 0, index)
def pop(self):
return heappop(self)
class SlowHeap(list):
def __init__(self, values):
super().__init__(values)
heapify(self)
self._broken = False
# Solution 2 and 3) repair using sort/heapify in a lazy way:
def update(self, index, value):
super().__setitem__(index, value)
self._broken = True
def __getitem__(self, index):
if self._broken:
self._repair()
self._broken = False
return super().__getitem__(index)
def _repair(self):
...
def pop(self):
if self._broken:
self._repair()
return heappop(self)
class HeapifyHeap(SlowHeap):
def _repair(self):
heapify(self)
class SortHeap(SlowHeap):
def _repair(self):
self.sort()
def rand_update(heap):
index = random.randint(0, len(heap)-1)
new_value = random.randint(max_int+1, max_int*2)
heap.update(index, new_value)
def rand_updates(update_count, heap):
for i in range(update_count):
rand_update(heap)
heap[0]
def verify(heap):
last = None
while heap:
item = heap.pop()
if last is not None and item < last:
raise RuntimeError(f"{item} was smaller than last {last}")
last = item
def run_perf_test(update_count, data, heap_class):
test_heap = heap_class(data)
t0 = time.time()
rand_updates(update_count, test_heap)
perf = (time.time() - t0)*1e3
verify(test_heap)
return perf
results = []
max_int = 500
update_count = 100
for i in range(2, 7):
test_size = 10**i
test_data = [random.randint(0, max_int) for _ in range(test_size)]
perf = run_perf_test(update_count, test_data, UpdateHeap)
results.append((test_size, "update", perf))
perf = run_perf_test(update_count, test_data, HeapifyHeap)
results.append((test_size, "heapify", perf))
perf = run_perf_test(update_count, test_data, SortHeap)
results.append((test_size, "sort", perf))
import pandas as pd
import seaborn as sns
dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)
sns.lineplot(
data=dtf,
x="heap size",
y="duration (ms)",
hue="method",
)
结果
如您所见,使用 _siftdown()
和 _siftup()
的“更新”方法渐进地更快。
您应该知道您的代码是做什么的,以及需要多长时间才能 运行。如果有疑问,您应该检查一下。 @cglaced 检查了执行需要多长时间,但他没有质疑应该需要多长时间。如果他这样做了,他会发现两者不匹配。其他人也为之倾倒。
heap size method duration (ms)
0 100 update 0.219107
1 100 heapify 0.412703
2 100 sort 0.242710
3 1000 update 0.198841
4 1000 heapify 2.947330
5 1000 sort 0.605345
6 10000 update 0.203848
7 10000 heapify 32.759190
8 10000 sort 4.621506
9 100000 update 0.348568
10 100000 heapify 327.646971
11 100000 sort 49.481153
12 1000000 update 0.256062
13 1000000 heapify 3475.244761
14 1000000 sort 1106.570005
处理私有函数
But _siftup and _siftdown are protected member in heapq, so they are not recommended to access from outside.
代码片段很短,因此您可以在将它们重命名为 public 函数后将它们包含在您自己的代码中:
def siftdown(heap, startpos, pos):
newitem = heap[pos]
# Follow the path to the root, moving parents down until finding a place
# newitem fits.
while pos > startpos:
parentpos = (pos - 1) >> 1
parent = heap[parentpos]
if newitem < parent:
heap[pos] = parent
pos = parentpos
continue
break
heap[pos] = newitem
def siftup(heap, pos):
endpos = len(heap)
startpos = pos
newitem = heap[pos]
# Bubble up the smaller child until hitting a leaf.
childpos = 2*pos + 1 # leftmost child position
while childpos < endpos:
# Set childpos to index of smaller child.
rightpos = childpos + 1
if rightpos < endpos and not heap[childpos] < heap[rightpos]:
childpos = rightpos
# Move the smaller child up.
heap[pos] = heap[childpos]
pos = childpos
childpos = 2*pos + 1
# The leaf at pos is empty now. Put newitem there, and bubble it up
# to its final resting place (by sifting its parents down).
heap[pos] = newitem
siftdown(heap, startpos, pos)
保持堆不变性
How to restore the heap invariant, when one element is out-of-order?
使用高级堆 API,您可以进行一系列计数更新,然后 运行 heapify(),然后再进行更多堆操作。这可能不足以满足您的需求。
也就是说,heapify()
函数非常 快。有趣的是,list.sort
方法更加优化,对于某些类型的输入可能优于 heapify()
。
更好的数据结构
need to count the frequency of words, and maintain the top k max-count words, which prepare to output at any moment. So I use heap here. When one word count++, I need update it if it is in heap.
考虑使用与堆不同的数据结构。起初,堆似乎很适合这个任务,但是在堆中找到任意条目很慢,即使我们可以用 siftup/siftdown.
快速更新它
相反,考虑保留从单词到单词列表和计数列表中的位置的字典映射。保持这些列表排序只需要在计数增加时交换位置:
from bisect import bisect_left
word2pos = {}
words = [] # ordered by descending frequency
counts = [] # negated to put most common first
def tally(word):
if word not in word2pos:
word2pos[word] = len(word2pos)
counts.append(-1)
words.append(word)
else:
pos = word2pos[word]
count = counts[pos]
swappos = bisect_left(counts, count, hi=pos)
words[pos] = swapword = words[swappos]
counts[pos] = counts[swappos]
word2pos[swapword] = pos
words[swappos] = word
counts[swappos] = count - 1
word2pos[word] = swappos
def topwords(n):
return [(-counts[i], words[i]) for i in range(n)]
内置解决方案
还有另一个“out-of-the-box”解决方案可能会满足您的需要。只需使用 collections.Counter():
>>> from collections import Counter
>>> c = Counter()
>>> for word in 'one two one three two three three'.split():
... c[word] += 1
...
>>> c.most_common(2)
[('three', 3), ('one', 2)]
二叉树或排序容器解决方案
另一种方法是使用二叉树或排序容器。这些具有 O(log n) 插入和删除。他们随时准备以正向或反向顺序迭代,无需额外的计算工作。
这是一个使用 Grant Jenk 的精彩 Sorted Containers 包的解决方案:
from sortedcontainers import SortedSet
from dataclasses import dataclass, field
from itertools import islice
@dataclass(order=True, unsafe_hash=True, slots=True)
class Entry:
count: int = field(hash=False)
word: str
w2e = {} # type: Dict[str, Entry]
ss = SortedSet() # type: Set[Entry]
def tally(word):
if word not in w2e:
entry = w2e[word] = Entry(1, word)
ss.add(entry)
else:
entry = w2e[word]
ss.remove(entry)
entry.count += 1
ss.add(entry)
def topwords(n):
return list(islice(reversed(ss), n)
我不知道如何在不使用 _siftup
或 _siftdown
的情况下有效地解决以下问题:
当一个元素乱序时,如何恢复堆不变性?
换句话说,将 heap
中的 old_value
更新为 new_value
,并保持 heap
正常工作。您可以假设堆中只有一个 old_value
。函数定义如下:
def update_value_in_heap(heap, old_value, new_value):
这是我的真实场景,有兴趣的可以看看
你可以想象它是一个小型的自动完成系统。我需要数数 词的频率,并保持前k个max-count词,这 随时准备输出。所以我在这里使用
heap
。当一个字 count++,如果它在堆中,我需要更新它。所有单词和计数都存储在 trie-tree 的叶子和堆中
存储在 trie-tree 的中间节点中。如果你关心这个词
out of heap,不用担心,我可以从trie-tree的叶子节点中获取它。当用户输入一个词时,它会先从堆中读取然后更新
它。为了更好的性能,我们可以考虑降低更新频率 by 批量更新
那么当一个特定的字数增加时如何更新堆?
这是 _siftup 或 _siftdown 版本的简单示例(不是我的场景):
>>> from heapq import _siftup, _siftdown, heapify, heappop
>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 22 # increase the 8 to 22
>>> i = data.index(old)
>>> data[i] = new
>>> _siftup(data, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 5, 7, 10, 18, 19, 22, 37]
>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 4 # decrease the 8 to 4
>>> i = data.index(old)
>>> data[i] = new
>>> _siftdown(data, 0, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 4, 5, 7, 10, 18, 19, 37]
索引成本为 O(n),更新成本为 O(logn)。 heapify
是另一种解决方案,但是
效率低于 _siftup
或 _siftdown
.
但是_siftup
和_siftdown
是heapq中受保护的成员,所以不建议从外部访问。
那么有没有更好更高效的方法来解决这个问题呢?这种情况的最佳做法?
感谢阅读,非常感谢它能帮助我。 :)
已参考
TL;DR 使用 heapify
.
您必须牢记的一件重要事情是,理论复杂性和性能是两个不同的事物(即使它们是相关的)。换句话说,实施也很重要。渐近复杂性给你一些 下限 你可以将其视为保证,例如 O(n) 中的算法确保在最坏的情况下,你将执行许多指令与输入大小成线性关系。这里有两件重要的事情:
- 常数被忽略,但常数在现实生活中很重要;
- 最坏的情况取决于您考虑的算法,而不仅仅是输入。
根据您考虑的topic/problem,第一点可能非常重要。在某些领域,隐藏在渐近复杂性中的常量是如此之大,以至于您甚至无法构建比常量更大的输入(或者考虑该输入是不现实的)。这不是这里的情况,但这是你必须始终牢记的事情。
鉴于这两个观察结果,您不能真的说:实现 B 比 A 快,因为 A 是从 O(n) 算法派生的,而 B 是从 O(log n) 算法派生的算法。即使这是一个很好的开端一般来说,它并不总是足够的。当所有输入都同样可能发生时,理论复杂性对于比较算法特别有用。换句话说,当你的算法非常通用时。
如果您知道您的用例和输入是什么,您可以直接测试性能。使用测试和渐近复杂度将使您对算法的执行方式有一个很好的了解(在极端情况和任意实际情况下)。
也就是说,让 运行 对以下 class 进行一些性能测试,这些 class 将实现
from heapq import _siftup, _siftdown, heapify, heappop
class Heap(list):
def __init__(self, values, sort=False, heap=False):
super().__init__(values)
heapify(self)
self._broken = False
self.sort = sort
self.heap = heap or not sort
# Solution 1) repair using the knowledge we have after every update:
def update(self, key, value):
old, self[key] = self[key], value
if value > old:
_siftup(self, key)
else:
_siftdown(self, 0, key)
# Solution 2 and 3) repair using sort/heapify in a lazzy way:
def __setitem__(self, key, value):
super().__setitem__(key, value)
self._broken = True
def __getitem__(self, key):
if self._broken:
self._repair()
self._broken = False
return super().__getitem__(key)
def _repair(self):
if self.sort:
self.sort()
elif self.heap:
heapify(self)
# … you'll also need to delegate all other heap functions, for example:
def pop(self):
self._repair()
return heappop(self)
我们可以先检查三种方法是否都有效:
data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
heap = Heap(data[:])
heap.update(8, 22)
heap.update(7, 4)
print(heap)
heap = Heap(data[:], sort_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)
heap = Heap(data[:], heap_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)
然后我们可以运行使用以下函数进行一些性能测试:
import time
import random
def rand_update(heap, lazzy_fix=False, **kwargs):
index = random.randint(0, len(heap)-1)
new_value = random.randint(max_int+1, max_int*2)
if lazzy_fix:
heap[index] = new_value
else:
heap.update(index, new_value)
def rand_updates(n, heap, lazzy_fix=False, **kwargs):
for _ in range(n):
rand_update(heap, lazzy_fix)
def run_perf_test(n, data, **kwargs):
test_heap = Heap(data[:], **kwargs)
t0 = time.time()
rand_updates(n, test_heap, **kwargs)
test_heap[0]
return (time.time() - t0)*1e3
results = []
max_int = 500
nb_updates = 1
for i in range(3, 7):
test_size = 10**i
test_data = [random.randint(0, max_int) for _ in range(test_size)]
perf = run_perf_test(nb_updates, test_data)
results.append((test_size, "update", perf))
perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, heap_fix=True)
results.append((test_size, "heapify", perf))
perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, sort_fix=True)
results.append((test_size, "sort", perf))
结果如下:
import pandas as pd
import seaborn as sns
dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)
sns.lineplot(
data=dtf,
x="heap size",
y="duration (ms)",
hue="method",
)
从这些测试中我们可以看出 heapify
似乎是最合理的选择,它在最坏的情况下具有不错的复杂度:O(n),并且在实践中表现更好。另一方面,研究其他选项可能是个好主意(比如拥有一个专用于该特定问题的数据结构,例如使用箱子将单词放入其中,然后将它们从一个箱子移动到下一个看起来像是一个可能的轨道调查)。
重要说明:这种情况(更新与读取比率 1:1)对 heapify
和 sort
解决方案都是不利的。所以如果你设法有一个k:1的比例,这个结论就更清楚了(你可以把上面代码中的nb_updates = 1
换成nb_updates = k
)。
数据框详细信息:
heap size method duration in ms
0 1000 update 0.435114
1 1000 heapify 0.073195
2 1000 sort 0.101089
3 10000 update 1.668930
4 10000 heapify 0.480175
5 10000 sort 1.151085
6 100000 update 13.194084
7 100000 heapify 4.875898
8 100000 sort 11.922121
9 1000000 update 153.587103
10 1000000 heapify 51.237106
11 1000000 sort 145.306110
@cglacet 的回答是完全错误的,但看起来很合法。
他提供的代码片段完全坏了!它也很难阅读。
_siftup()
在 heapify()
中被调用 n//2 次,因此它本身不能比 _siftup()
快。
回答原问题,没有更好的办法。如果您担心这些方法是私有的,请创建您自己的方法来做同样的事情。
我唯一同意的是,如果长时间不需要从堆中读取,可能有利于懒惰heapify()
一旦你需要它们。问题是你是否应该为此使用堆。
让我们用他的代码片段来解决问题:
heapify()
函数被多次调用以进行“更新”运行。
导致这种情况的错误链如下:
- 他通过了
heap_fix
,但预计heap
sort
也是如此
- 如果
self.sort
总是False
,那么self.heap
总是True
- 他重新定义了
__getitem__()
和__setitem__()
,每次_siftdown()
的_siftup()
赋值或读取某些东西时都会调用它们(注意:这两个在C中没有调用,所以他们使用__getitem__()
和__setitem__()
) - 如果
self.heap
是True
并且正在调用__getitem__()
和__setitem__()
,则每次_siftup()
或调用_repair()
函数siftdown()
交换元素。但是对heapify()
的调用是在 C 中完成的,所以__getitem__()
不会被调用,也不会陷入无限循环 - 他重新定义了
self.sort
所以调用它,就像他尝试做的那样,会失败 - 他读过一次,但更新了一个项目
nb_updates
次,而不是他声称的 1:1
我修正了这个例子,我试着尽可能地验证它,但我们都会犯错误。有空自己查一下。
代码
import time
import random
from heapq import _siftup, _siftdown, heapify, heappop
class UpdateHeap(list):
def __init__(self, values):
super().__init__(values)
heapify(self)
def update(self, index, value):
old, self[index] = self[index], value
if value > old:
_siftup(self, index)
else:
_siftdown(self, 0, index)
def pop(self):
return heappop(self)
class SlowHeap(list):
def __init__(self, values):
super().__init__(values)
heapify(self)
self._broken = False
# Solution 2 and 3) repair using sort/heapify in a lazy way:
def update(self, index, value):
super().__setitem__(index, value)
self._broken = True
def __getitem__(self, index):
if self._broken:
self._repair()
self._broken = False
return super().__getitem__(index)
def _repair(self):
...
def pop(self):
if self._broken:
self._repair()
return heappop(self)
class HeapifyHeap(SlowHeap):
def _repair(self):
heapify(self)
class SortHeap(SlowHeap):
def _repair(self):
self.sort()
def rand_update(heap):
index = random.randint(0, len(heap)-1)
new_value = random.randint(max_int+1, max_int*2)
heap.update(index, new_value)
def rand_updates(update_count, heap):
for i in range(update_count):
rand_update(heap)
heap[0]
def verify(heap):
last = None
while heap:
item = heap.pop()
if last is not None and item < last:
raise RuntimeError(f"{item} was smaller than last {last}")
last = item
def run_perf_test(update_count, data, heap_class):
test_heap = heap_class(data)
t0 = time.time()
rand_updates(update_count, test_heap)
perf = (time.time() - t0)*1e3
verify(test_heap)
return perf
results = []
max_int = 500
update_count = 100
for i in range(2, 7):
test_size = 10**i
test_data = [random.randint(0, max_int) for _ in range(test_size)]
perf = run_perf_test(update_count, test_data, UpdateHeap)
results.append((test_size, "update", perf))
perf = run_perf_test(update_count, test_data, HeapifyHeap)
results.append((test_size, "heapify", perf))
perf = run_perf_test(update_count, test_data, SortHeap)
results.append((test_size, "sort", perf))
import pandas as pd
import seaborn as sns
dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)
sns.lineplot(
data=dtf,
x="heap size",
y="duration (ms)",
hue="method",
)
结果
如您所见,使用 _siftdown()
和 _siftup()
的“更新”方法渐进地更快。
您应该知道您的代码是做什么的,以及需要多长时间才能 运行。如果有疑问,您应该检查一下。 @cglaced 检查了执行需要多长时间,但他没有质疑应该需要多长时间。如果他这样做了,他会发现两者不匹配。其他人也为之倾倒。
heap size method duration (ms)
0 100 update 0.219107
1 100 heapify 0.412703
2 100 sort 0.242710
3 1000 update 0.198841
4 1000 heapify 2.947330
5 1000 sort 0.605345
6 10000 update 0.203848
7 10000 heapify 32.759190
8 10000 sort 4.621506
9 100000 update 0.348568
10 100000 heapify 327.646971
11 100000 sort 49.481153
12 1000000 update 0.256062
13 1000000 heapify 3475.244761
14 1000000 sort 1106.570005
处理私有函数
But _siftup and _siftdown are protected member in heapq, so they are not recommended to access from outside.
代码片段很短,因此您可以在将它们重命名为 public 函数后将它们包含在您自己的代码中:
def siftdown(heap, startpos, pos):
newitem = heap[pos]
# Follow the path to the root, moving parents down until finding a place
# newitem fits.
while pos > startpos:
parentpos = (pos - 1) >> 1
parent = heap[parentpos]
if newitem < parent:
heap[pos] = parent
pos = parentpos
continue
break
heap[pos] = newitem
def siftup(heap, pos):
endpos = len(heap)
startpos = pos
newitem = heap[pos]
# Bubble up the smaller child until hitting a leaf.
childpos = 2*pos + 1 # leftmost child position
while childpos < endpos:
# Set childpos to index of smaller child.
rightpos = childpos + 1
if rightpos < endpos and not heap[childpos] < heap[rightpos]:
childpos = rightpos
# Move the smaller child up.
heap[pos] = heap[childpos]
pos = childpos
childpos = 2*pos + 1
# The leaf at pos is empty now. Put newitem there, and bubble it up
# to its final resting place (by sifting its parents down).
heap[pos] = newitem
siftdown(heap, startpos, pos)
保持堆不变性
How to restore the heap invariant, when one element is out-of-order?
使用高级堆 API,您可以进行一系列计数更新,然后 运行 heapify(),然后再进行更多堆操作。这可能不足以满足您的需求。
也就是说,heapify()
函数非常 快。有趣的是,list.sort
方法更加优化,对于某些类型的输入可能优于 heapify()
。
更好的数据结构
need to count the frequency of words, and maintain the top k max-count words, which prepare to output at any moment. So I use heap here. When one word count++, I need update it if it is in heap.
考虑使用与堆不同的数据结构。起初,堆似乎很适合这个任务,但是在堆中找到任意条目很慢,即使我们可以用 siftup/siftdown.
快速更新它相反,考虑保留从单词到单词列表和计数列表中的位置的字典映射。保持这些列表排序只需要在计数增加时交换位置:
from bisect import bisect_left
word2pos = {}
words = [] # ordered by descending frequency
counts = [] # negated to put most common first
def tally(word):
if word not in word2pos:
word2pos[word] = len(word2pos)
counts.append(-1)
words.append(word)
else:
pos = word2pos[word]
count = counts[pos]
swappos = bisect_left(counts, count, hi=pos)
words[pos] = swapword = words[swappos]
counts[pos] = counts[swappos]
word2pos[swapword] = pos
words[swappos] = word
counts[swappos] = count - 1
word2pos[word] = swappos
def topwords(n):
return [(-counts[i], words[i]) for i in range(n)]
内置解决方案
还有另一个“out-of-the-box”解决方案可能会满足您的需要。只需使用 collections.Counter():
>>> from collections import Counter
>>> c = Counter()
>>> for word in 'one two one three two three three'.split():
... c[word] += 1
...
>>> c.most_common(2)
[('three', 3), ('one', 2)]
二叉树或排序容器解决方案
另一种方法是使用二叉树或排序容器。这些具有 O(log n) 插入和删除。他们随时准备以正向或反向顺序迭代,无需额外的计算工作。
这是一个使用 Grant Jenk 的精彩 Sorted Containers 包的解决方案:
from sortedcontainers import SortedSet
from dataclasses import dataclass, field
from itertools import islice
@dataclass(order=True, unsafe_hash=True, slots=True)
class Entry:
count: int = field(hash=False)
word: str
w2e = {} # type: Dict[str, Entry]
ss = SortedSet() # type: Set[Entry]
def tally(word):
if word not in w2e:
entry = w2e[word] = Entry(1, word)
ss.add(entry)
else:
entry = w2e[word]
ss.remove(entry)
entry.count += 1
ss.add(entry)
def topwords(n):
return list(islice(reversed(ss), n)