实施中线维护

Implementation of Median Maintenance

我正在尝试解决我正在参加的在线课程中的问题,但我相信我被卡住了。

这就是问题

The goal of this problem is to implement the "Median Maintenance" algorithm. The text file contains a list of the integers from 1 to 10000 in unsorted order; you should treat this as a stream of numbers, arriving one by one. Letting xi denote the ith number of the file, the kth median mk is defined as the median of the numbers x1,…,xk. (So, if k is odd, then mk is ((k+1)/2)th smallest number among x1,…,xk; if k is even, then mk is the (k/2)th smallest number among x1,…,xk.)

Find the sum of the 1000 medians.

下面是我的代码,它输出了错误的答案,我似乎无法弄清楚出了什么问题

import heapq
# all_ints = list(map(int, open("stanford_algo/course_2_graph_search/median.txt").read().splitlines()))
all_ints = [6331, 2793, 1640, 9290, 225, 625, 6195, 2303, 5685, 1354]
min_heap_elements =  [all_ints[0]] # has all elements more than median
max_heap_elements =  [all_ints[1]] # has all elements less than median
heapq.heapify(min_heap_elements) # has all elements more than median
heapq._heapify_max(max_heap_elements) # has all elements less than median
medians = []
medians.append(all_ints[0])
medians.append(all_ints[1]) #doing this because I can see the first two elements are in decreasing order

for i, next_int in enumerate(all_ints[2:],start=3):
    if next_int > min(min_heap_elements):
        heapq.heappush(min_heap_elements, next_int)
        heapq.heapify(min_heap_elements)
    elif next_int <=  max(max_heap_elements):
        max_heap_elements.append(next_int)
        heapq._heapify_max(max_heap_elements)
    else:
        if len(min_heap_elements) > len(max_heap_elements):
            max_heap_elements.append(next_int)
            heapq._heapify_max(max_heap_elements)
        else:
            heapq.heappush(min_heap_elements, next_int)
            heapq.heapify(min_heap_elements)
    if len(max_heap_elements) - len(min_heap_elements) > 1:
        extract = max_heap_elements.pop(0)
        heapq.heappush(min_heap_elements, extract)
        heapq._heapify_max(max_heap_elements)
        heapq.heapify(min_heap_elements)
    elif len(min_heap_elements) - len(max_heap_elements) > 1:
        extract = min_heap_elements.pop(0)
        max_heap_elements.append(extract)
        heapq._heapify_max(max_heap_elements)
        heapq.heapify(min_heap_elements)
    median = [max(max_heap_elements), min(min_heap_elements)][(i)%2]
    medians.append(median)

sum(medians)%10000 # should be 9335

我在这里使用了两个堆。一个用于将大于媒体的元素存储在最小堆(min_heap_elements)中,另一个堆(max_heap_elements)用于存储小于中位数的元素。对于每个新元素,如果它小于(或等于)最大堆的最大元素,我将它添加到 max_heap_elements。我

如果新元素大于最小堆的最小元素,我将其添加到min_heap_elements。如果这两种情况都不是,我会查看哪个堆更短并将其添加到那个堆中。

但是,我正在这里做一些事情,我不能把我的手指放在上面。

编辑:

这些是我得到的中位数

>>> medians
[6331, 2793, 6331, 2793, 6331, 1640, 2793, 2303, 2793, 2303]

这就是我所期待的

>>> correct_medians
[6331, 2793, 2793, 2793, 2793, 1640, 2793, 2303, 2793, 2303]

问题在于如何计算两个堆的中位数,因为当索引为奇数时,不能保证左边的元素比右边的元素多一个。

你应该这样做

if len(max_heap_elements) == len(min_heap_elements):
    median = max(max_heap_elements)
elif len(max_heap_elements) > len(min_heap_elements):
    median = max(max_heap_elements)
else:
    median = min(min_heap_elements)

此外,请注意,如果您正在使用堆,是因为您想要实现 O(nlogn) 解决方案,但是,通过重复调用 heapifymaxmin, 你不会得到想要的时间复杂度。

而不是 min(min_heap_elements)min_heap_elements[0],删除 heappush 之后的 heapify 调用,而不是列表的 pop 使用 heappop

最后,对于最大堆,您可以得到一个包含取反值的列表,因为 heapq 模块不支持最大堆,它们仅 "support" 一些操作,例如 _heappop_max,但是没有 _heappush_max,所以你总是需要调用 _heapify_max

编辑: 如果不需要时间复杂度,您可以只使用标准库中的函数 statistics.median_low

我正在学习相同的课程,所以我的解决方案在这里。

import heapq
import statistics

with open('home_work_week3_Median.txt', 'r') as file:
    line = file.read().strip().split('\n')
file.close()
assert len(line) == 10000

# solution with heapq
nn = []
num = 0

for n, v in enumerate(line):

    nn.append(v) 

#     index = (len(nn) - 1) // 2 # here we get the index of new list
#     new = heapq.nsmallest(n+1, nn) # get the smallest num of heapqed list
#     print(new[index], new, index)

    # here we combine it into one line
    num += int(heapq.nsmallest(n+1, nn)[(len(nn) - 1) // 2])

num%10000


# solution with statistics lib
num = 0

for index, number in enumerate(test):
    num += int(statistics.median_low(test[:index+1]))

num%10000

我不知道我的解决方案是否正确,两种解决方案的执行时间相同(大约 11 秒)。我相信它会更好