如何避免在 heapq 中使用 _siftup 或 _siftdown

how to avoid using _siftup or _siftdown in heapq

我不知道如何在不使用 _siftup_siftdown 的情况下有效地解决以下问题:

当一个元素乱序时,如何恢复堆不变性?

换句话说,将 heap 中的 old_value 更新为 new_value,并保持 heap 正常工作。您可以假设堆中只有一个 old_value。函数定义如下:

def update_value_in_heap(heap, old_value, new_value):

这是我的真实场景,有兴趣的可以看看

那么当一个特定的字数增加时如何更新堆?

这是 _siftup 或 _siftdown 版本的简单示例(不是我的场景):

>>> from heapq import _siftup, _siftdown, heapify, heappop

>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 22              # increase the 8 to 22
>>> i = data.index(old)
>>> data[i] = new
>>> _siftup(data, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 5, 7, 10, 18, 19, 22, 37]

>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 4              # decrease the 8 to 4
>>> i = data.index(old)
>>> data[i] = new
>>> _siftdown(data, 0, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 4, 5, 7, 10, 18, 19, 37]

索引成本为 O(n),更新成本为 O(logn)。 heapify 是另一种解决方案,但是 效率低于 _siftup_siftdown.

但是_siftup_siftdown是heapq中受保护的成员,所以不建议从外部访问。

那么有没有更好更高效的方法来解决这个问题呢?这种情况的最佳做法?

感谢阅读,非常感谢它能帮助我。 :)

已参考 ,但没有解决我的问题

TL;DR 使用 heapify.

您必须牢记的一件重要事情是,理论复杂性和性能是两个不同的事物(即使它们是相关的)。换句话说,实施也很重要。渐近复杂性给你一些 下限 你可以将其视为保证,例如 O(n) 中的算法确保在最坏的情况下,你将执行许多指令与输入大小成线性关系。这里有两件重要的事情:

  1. 常数被忽略,但常数在现实生活中很重要;
  2. 最坏的情况取决于您考虑的算法,而不仅仅是输入。

根据您考虑的topic/problem,第一点可能非常重要。在某些领域,隐藏在渐近复杂性中的常量是如此之大,以至于您甚至无法构建比常量更大的输入(或者考虑该输入是不现实的)。这不是这里的情况,但这是你必须始终牢记的事情。

鉴于这两个观察结果,您不能真的说:实现 B 比 A 快,因为 A 是从 O(n) 算法派生的,而 B 是从 O(log n) 算法派生的算法。即使这是一个很好的开端一般来说,它并不总是足够的。当所有输入都同样可能发生时,理论复杂性对于比较算法特别有用。换句话说,当你的算法非常通用时。

如果您知道您的用例和输入是什么,您可以直接测试性能。使用测试和渐近复杂度将使您对算法的执行方式有一个很好的了解(在极端情况和任意实际情况下)。

也就是说,让 运行 对以下 class 进行一些性能测试,这些 class 将实现 (there are actually four strategies here, but Invalidate and Reinsert doesn't seem right in your case as you'll invalidate each item as many time as you see a given word). I'll include most of my code so you can double check that I haven't messed up (you can even check the complete notebook):

from heapq import _siftup, _siftdown, heapify, heappop

class Heap(list):
  def __init__(self, values, sort=False, heap=False):
    super().__init__(values)
    heapify(self)
    self._broken = False
    self.sort = sort
    self.heap = heap or not sort

  # Solution 1) repair using the knowledge we have after every update:        
  def update(self, key, value):
    old, self[key] = self[key], value
    if value > old:
        _siftup(self, key)
    else:
        _siftdown(self, 0, key)
    
  # Solution 2 and 3) repair using sort/heapify in a lazzy way:
  def __setitem__(self, key, value):
    super().__setitem__(key, value)
    self._broken = True
    
  def __getitem__(self, key):
    if self._broken:
        self._repair()
        self._broken = False
    return super().__getitem__(key)

  def _repair(self):  
    if self.sort:
        self.sort()
    elif self.heap:
        heapify(self)

  # … you'll also need to delegate all other heap functions, for example:
  def pop(self):
    self._repair()
    return heappop(self)

我们可以先检查三种方法是否都有效:

data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]

heap = Heap(data[:])
heap.update(8, 22)
heap.update(7, 4)
print(heap)

heap = Heap(data[:], sort_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)

heap = Heap(data[:], heap_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)

然后我们可以运行使用以下函数进行一些性能测试:

import time
import random

def rand_update(heap, lazzy_fix=False, **kwargs):
    index = random.randint(0, len(heap)-1)
    new_value = random.randint(max_int+1, max_int*2)
    if lazzy_fix:
        heap[index] = new_value
    else:
        heap.update(index, new_value)
    
def rand_updates(n, heap, lazzy_fix=False, **kwargs):
    for _ in range(n):
        rand_update(heap, lazzy_fix)
        
def run_perf_test(n, data, **kwargs):
    test_heap = Heap(data[:], **kwargs)
    t0 = time.time()
    rand_updates(n, test_heap, **kwargs)
    test_heap[0]
    return (time.time() - t0)*1e3

results = []
max_int = 500
nb_updates = 1

for i in range(3, 7):
    test_size = 10**i
    test_data = [random.randint(0, max_int) for _ in range(test_size)]

    perf = run_perf_test(nb_updates, test_data)
    results.append((test_size, "update", perf))
    
    perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, heap_fix=True)
    results.append((test_size, "heapify", perf))

    perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, sort_fix=True)
    results.append((test_size, "sort", perf))

结果如下:

import pandas as pd
import seaborn as sns

dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)

sns.lineplot(
    data=dtf, 
    x="heap size", 
    y="duration (ms)", 
    hue="method",
)

从这些测试中我们可以看出 heapify 似乎是最合理的选择,它在最坏的情况下具有不错的复杂度:O(n),并且在实践中表现更好。另一方面,研究其他选项可能是个好主意(比如拥有一个专用于该特定问题的数据结构,例如使用箱子将单词放入其中,然后将它们从一个箱子移动到下一个看起来像是一个可能的轨道调查)。

重要说明:这种情况(更新与读取比率 1:1)对 heapifysort 解决方案都是不利的。所以如果你设法有一个k:1的比例,这个结论就更清楚了(你可以把上面代码中的nb_updates = 1换成nb_updates = k)。

数据框详细信息:

    heap size   method  duration in ms
0        1000   update        0.435114
1        1000  heapify        0.073195
2        1000     sort        0.101089
3       10000   update        1.668930
4       10000  heapify        0.480175
5       10000     sort        1.151085
6      100000   update       13.194084
7      100000  heapify        4.875898
8      100000     sort       11.922121
9     1000000   update      153.587103
10    1000000  heapify       51.237106
11    1000000     sort      145.306110

@cglacet 的回答是完全错误的,但看起来很合法。 他提供的代码片段完全坏了!它也很难阅读。 _siftup()heapify() 中被调用 n//2 次,因此它本身不能比 _siftup() 快。

回答原问题,没有更好的办法。如果您担心这些方法是私有的,请创建您自己的方法来做同样的事情。

我唯一同意的是,如果长时间不需要从堆中读取,可能有利于懒惰heapify() 一旦你需要它们。问题是你是否应该为此使用堆。

让我们用他的代码片段来解决问题:

heapify() 函数被多次调用以进行“更新”运行。 导致这种情况的错误链如下:

  • 他通过了heap_fix,但预计heap sort
  • 也是如此
  • 如果 self.sort 总是 False,那么 self.heap 总是 True
  • 他重新定义了__getitem__()__setitem__(),每次_siftdown()_siftup()赋值或读取某些东西时都会调用它们(注意:这两个在C中没有调用,所以他们使用 __getitem__()__setitem__())
  • 如果self.heapTrue并且正在调用__getitem__()__setitem__(),则每次_siftup()或调用_repair()函数siftdown() 交换元素。但是对 heapify() 的调用是在 C 中完成的,所以 __getitem__() 不会被调用,也不会陷入无限循环
  • 他重新定义了 self.sort 所以调用它,就像他尝试做的那样,会失败
  • 他读过一次,但更新了一个项目 nb_updates 次,而不是他声称的 1:1

我修正了这个例子,我试着尽可能地验证它,但我们都会犯错误。有空自己查一下。

代码

import time
import random

from heapq import _siftup, _siftdown, heapify, heappop

class UpdateHeap(list):
    def __init__(self, values):
        super().__init__(values)
        heapify(self)

    def update(self, index, value):
        old, self[index] = self[index], value
        if value > old:
            _siftup(self, index)
        else:
            _siftdown(self, 0, index)

    def pop(self):
        return heappop(self)

class SlowHeap(list):
    def __init__(self, values):
        super().__init__(values)
        heapify(self)
        self._broken = False
        
    # Solution 2 and 3) repair using sort/heapify in a lazy way:
    def update(self, index, value):
        super().__setitem__(index, value)
        self._broken = True
    
    def __getitem__(self, index):
        if self._broken:
            self._repair()
            self._broken = False
        return super().__getitem__(index)

    def _repair(self):
        ...

    def pop(self):
        if self._broken:
            self._repair()
        return heappop(self)

class HeapifyHeap(SlowHeap):

    def _repair(self):
        heapify(self)


class SortHeap(SlowHeap):

    def _repair(self):
        self.sort()

def rand_update(heap):
    index = random.randint(0, len(heap)-1)
    new_value = random.randint(max_int+1, max_int*2)
    heap.update(index, new_value)
    
def rand_updates(update_count, heap):
    for i in range(update_count):
        rand_update(heap)
        heap[0]
        
def verify(heap):
    last = None
    while heap:
        item = heap.pop()
        if last is not None and item < last:
            raise RuntimeError(f"{item} was smaller than last {last}")
        last = item

def run_perf_test(update_count, data, heap_class):
    test_heap = heap_class(data)
    t0 = time.time()
    rand_updates(update_count, test_heap)
    perf = (time.time() - t0)*1e3
    verify(test_heap)
    return perf


results = []
max_int = 500
update_count = 100

for i in range(2, 7):
    test_size = 10**i
    test_data = [random.randint(0, max_int) for _ in range(test_size)]

    perf = run_perf_test(update_count, test_data, UpdateHeap)
    results.append((test_size, "update", perf))
    
    perf = run_perf_test(update_count, test_data, HeapifyHeap)
    results.append((test_size, "heapify", perf))

    perf = run_perf_test(update_count, test_data, SortHeap)
    results.append((test_size, "sort", perf))

import pandas as pd
import seaborn as sns

dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)

sns.lineplot(
    data=dtf, 
    x="heap size", 
    y="duration (ms)", 
    hue="method",
)

结果

如您所见,使用 _siftdown()_siftup() 的“更新”方法渐进地更快。

您应该知道您的代码是做什么的,以及需要多长时间才能 运行。如果有疑问,您应该检查一下。 @cglaced 检查了执行需要多长时间,但他没有质疑应该需要多长时间。如果他这样做了,他会发现两者不匹配。其他人也为之倾倒。

    heap size   method  duration (ms)
0         100   update       0.219107
1         100  heapify       0.412703
2         100     sort       0.242710
3        1000   update       0.198841
4        1000  heapify       2.947330
5        1000     sort       0.605345
6       10000   update       0.203848
7       10000  heapify      32.759190
8       10000     sort       4.621506
9      100000   update       0.348568
10     100000  heapify     327.646971
11     100000     sort      49.481153
12    1000000   update       0.256062
13    1000000  heapify    3475.244761
14    1000000     sort    1106.570005

处理私有函数

But _siftup and _siftdown are protected member in heapq, so they are not recommended to access from outside.

代码片段很短,因此您可以在将它们重命名为 public 函数后将它们包含在您自己的代码中:

def siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem

def siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    # Bubble up the smaller child until hitting a leaf.
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child.
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        # Move the smaller child up.
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    # The leaf at pos is empty now.  Put newitem there, and bubble it up
    # to its final resting place (by sifting its parents down).
    heap[pos] = newitem
    siftdown(heap, startpos, pos)

保持堆不变性

How to restore the heap invariant, when one element is out-of-order?

使用高级堆 API,您可以进行一系列计数更新,然后 运行 heapify(),然后再进行更多堆操作。这可能不足以满足您的需求。

也就是说,heapify() 函数非常 快。有趣的是,list.sort 方法更加优化,对于某些类型的输入可能优于 heapify()

更好的数据结构

need to count the frequency of words, and maintain the top k max-count words, which prepare to output at any moment. So I use heap here. When one word count++, I need update it if it is in heap.

考虑使用与堆不同的数据结构。起初,堆似乎很适合这个任务,但是在堆中找到任意条目很慢,即使我们可以用 siftup/siftdown.

快速更新它

相反,考虑保留从单词到单词列表和计数列表中的位置的字典映射。保持这些列表排序只需要在计数增加时交换位置:

from bisect import bisect_left

word2pos = {}
words = []    # ordered by descending frequency
counts = []   # negated to put most common first

def tally(word):
    if word not in word2pos:
        word2pos[word] = len(word2pos)
        counts.append(-1)
        words.append(word)
    else:
        pos = word2pos[word]
        count = counts[pos]
        swappos = bisect_left(counts, count, hi=pos)
        words[pos] = swapword = words[swappos]
        counts[pos] = counts[swappos]
        word2pos[swapword] = pos
        words[swappos] = word
        counts[swappos] = count - 1
        word2pos[word] = swappos

def topwords(n):
    return [(-counts[i], words[i]) for i in range(n)]

内置解决方案

还有另一个“out-of-the-box”解决方案可能会满足您的需要。只需使用 collections.Counter():

>>> from collections import Counter
>>> c = Counter()
>>> for word in 'one two one three two three three'.split():
...     c[word] += 1
...
>>> c.most_common(2)
[('three', 3), ('one', 2)]

二叉树或排序容器解决方案

另一种方法是使用二叉树或排序容器。这些具有 O(log n) 插入和删除。他们随时准备以正向或反向顺序迭代,无需额外的计算工作。

这是一个使用 Grant Jenk 的精彩 Sorted Containers 包的解决方案:

from sortedcontainers import SortedSet
from dataclasses import dataclass, field
from itertools import islice

@dataclass(order=True, unsafe_hash=True, slots=True)
class Entry:
    count: int = field(hash=False)
    word: str

w2e = {}                  # type: Dict[str, Entry]
ss = SortedSet()          # type: Set[Entry]   

def tally(word):
    if word not in w2e:
        entry = w2e[word] = Entry(1, word)
        ss.add(entry)
    else:
        entry = w2e[word]
        ss.remove(entry)
        entry.count += 1
        ss.add(entry)

def topwords(n):
    return list(islice(reversed(ss), n)