没有显式数组的二进制搜索
Binary-search without an explicit array
我想使用例如np.searchsorted
,但是,我不想创建包含值的显式数组。相反,我想定义一个函数,给出数组所需位置的预期值,例如p(i) = i
,其中 i 表示数组中的位置。
在我的例子中,生成一个关于函数的值数组既不高效也不优雅。有什么办法可以实现吗?
受@norok2 评论的启发,我认为你可以使用这样的东西:
def f(i):
return i*2 # Just an example
class MySeq(Sequence):
def __init__(self, f, maxi):
self.maxi = maxi
self.f = f
def __getitem__(self, x):
if x < 0 or x > self.maxi:
raise IndexError()
return self.f(x)
def __len__(self):
return self.maxi + 1
在这种情况下,f
是您的函数,而 maxi
是最大索引。这当然只有在函数 f
return 值按排序顺序时才有效。
此时你可以在np.searchsorted
.
里面使用一个MySeq
类型的对象
像这样的东西怎么样:
import collections
class GeneratorSequence(collections.Sequence):
def __init__(self, func, size):
self._func = func
self._len = size
def __len__(self):
return self._len
def __getitem__(self, i):
if 0 <= i < self._len:
return self._func(i)
else:
raise IndexError
def __iter__(self):
for i in range(self._len):
yield self[i]
这适用于 np.searchsorted()
,例如:
import numpy as np
gen_seq = GeneratorSequence(lambda x: x ** 2, 100)
np.searchsorted(gen_seq, 9)
# 3
您也可以编写自己的二进制搜索函数,在这种情况下您并不需要 NumPy,它实际上是有益的:
def bin_search(seq, item):
first = 0
last = len(seq) - 1
found = False
while first <= last and not found:
midpoint = (first + last) // 2
if seq[midpoint] == item:
first = midpoint
found = True
else:
if item < seq[midpoint]:
last = midpoint - 1
else:
first = midpoint + 1
return first
给出相同的结果:
all(bin_search(gen_seq, i) == np.searchsorted(gen_seq, i) for i in range(100))
# True
顺便说一句,这也是 WAY 更快:
gen_seq = GeneratorSequence(lambda x: x ** 2, 1000000)
%timeit np.searchsorted(gen_seq, 10000)
# 1 loop, best of 3: 1.23 s per loop
%timeit bin_search(gen_seq, 10000)
# 100000 loops, best of 3: 16.1 µs per loop
我想使用例如np.searchsorted
,但是,我不想创建包含值的显式数组。相反,我想定义一个函数,给出数组所需位置的预期值,例如p(i) = i
,其中 i 表示数组中的位置。
在我的例子中,生成一个关于函数的值数组既不高效也不优雅。有什么办法可以实现吗?
受@norok2 评论的启发,我认为你可以使用这样的东西:
def f(i):
return i*2 # Just an example
class MySeq(Sequence):
def __init__(self, f, maxi):
self.maxi = maxi
self.f = f
def __getitem__(self, x):
if x < 0 or x > self.maxi:
raise IndexError()
return self.f(x)
def __len__(self):
return self.maxi + 1
在这种情况下,f
是您的函数,而 maxi
是最大索引。这当然只有在函数 f
return 值按排序顺序时才有效。
此时你可以在np.searchsorted
.
MySeq
类型的对象
像这样的东西怎么样:
import collections
class GeneratorSequence(collections.Sequence):
def __init__(self, func, size):
self._func = func
self._len = size
def __len__(self):
return self._len
def __getitem__(self, i):
if 0 <= i < self._len:
return self._func(i)
else:
raise IndexError
def __iter__(self):
for i in range(self._len):
yield self[i]
这适用于 np.searchsorted()
,例如:
import numpy as np
gen_seq = GeneratorSequence(lambda x: x ** 2, 100)
np.searchsorted(gen_seq, 9)
# 3
您也可以编写自己的二进制搜索函数,在这种情况下您并不需要 NumPy,它实际上是有益的:
def bin_search(seq, item):
first = 0
last = len(seq) - 1
found = False
while first <= last and not found:
midpoint = (first + last) // 2
if seq[midpoint] == item:
first = midpoint
found = True
else:
if item < seq[midpoint]:
last = midpoint - 1
else:
first = midpoint + 1
return first
给出相同的结果:
all(bin_search(gen_seq, i) == np.searchsorted(gen_seq, i) for i in range(100))
# True
顺便说一句,这也是 WAY 更快:
gen_seq = GeneratorSequence(lambda x: x ** 2, 1000000)
%timeit np.searchsorted(gen_seq, 10000)
# 1 loop, best of 3: 1.23 s per loop
%timeit bin_search(gen_seq, 10000)
# 100000 loops, best of 3: 16.1 µs per loop