如何优化多维数组搜索例程(每行一个条件)

How to optimize a multi-dimensional array search routine (one condition per row)

给定一个多维数据数组,我想找出同时满足每一行条件的列。我已经有一个可以工作的算法,但希望进一步优化。我的方法可以接受多种条件,而代码审查(Code Review)帖子中建议的方法不能。我想改造代码审查中建议的方法,使其也支持多种条件。

例如,考虑一些样本数据。

import numpy as np

def get_sample_data(nsample):
    """Return one of the two built-in three-row sample arrays.

    Args:
        nsample: Selects the sample. ``0`` gives a hand-written integer
            sample with 8 columns; ``1`` gives a float sample with 10
            columns derived from ``np.linspace(1, 10, 10)``.

    Returns:
        A 2-D array whose rows are ``row_a``, ``row_b`` and ``row_c``.

    Raises:
        ValueError: If ``nsample`` is neither 0 nor 1.
    """
    if nsample == 0:
        row_a = np.array([1, 4, 7, 3, 10, 3, 5, 1])
        row_b = np.array([2, 5, 30, 30, 10, 5, 5, 1])
        row_c = np.array([23, 21, 22, 23, 23, 25, 21, 23])
    elif nsample == 1:
        row_a = np.linspace(1, 10, 10)
        row_b = row_a * 10
        row_c = row_a + 20
    else:
        # Without this branch the rows would be unbound and the function
        # would fail with an unhelpful UnboundLocalError.
        raise ValueError("nsample should be 0 or 1, got {!r}".format(nsample))
    return np.array([row_a, row_b, row_c])

# Use the hand-written integer sample (nsample=0); swap in the commented-out
# call below to try the linspace-based sample instead.
data = get_sample_data(0)
# data = get_sample_data(1)

我编写了一个函数来帮助简化为每行分配各种条件之一的过程。

def search(condition, value, relation):
    """Return the indices at which ``condition`` satisfies ``relation``.

    Args:
        condition: Array to search. For the ``'custom'`` relation it is
            used directly as the truth array and ``value`` is ignored.
        value: Reference value each element is compared against.
        relation: Name of the comparison. One of ``'equality'`` /
            ``'exact match'``, ``'not equal'``, ``'greater than'``,
            ``'less than'`` / ``'lesser than'``, ``'greater than or equal'``,
            ``'less than or equal'`` / ``'lesser than or equal'``,
            ``'nearest'``, ``'nearest forward'``, ``'nearest backward'`` or
            ``'custom'``.

    Returns:
        The tuple of index arrays produced by ``np.where``.

    Raises:
        ValueError: If ``relation`` is not recognised, or if a
            forward/backward nearest search has no candidate on the
            requested side of ``value``.
    """
    if relation in ('equality', 'exact match'):
        return np.where(condition == value)
    if relation == 'not equal':
        return np.where(condition != value)
    if relation == 'greater than':
        return np.where(condition > value)
    if relation in ('less than', 'lesser than'):
        return np.where(condition < value)
    if relation == 'greater than or equal':
        return np.where(condition >= value)
    if relation in ('less than or equal', 'lesser than or equal'):
        return np.where(condition <= value)
    if relation == 'nearest':
        delta = np.abs(condition - value)
        return np.where(delta == np.min(delta))
    if relation in ('nearest forward', 'nearest backward'):
        forward = relation == 'nearest forward'
        # Signed distance, oriented so that valid candidates are >= 0.
        delta = condition - value if forward else value - condition
        candidates = delta[delta >= 0]
        # Check for "no candidate" explicitly instead of wrapping np.min in
        # a bare except, which would also hide unrelated errors.
        if candidates.size == 0:
            raise ValueError("no {}-nearest match exists".format(
                'forward' if forward else 'backward'))
        return np.where(delta == np.min(candidates))
    if relation == 'custom':
        return np.where(condition)
    raise ValueError("the input search relation is invalid")

下面是我的实现,运行成功。

def get_base(shape, value, dtype=int):
    """Return an array of the given shape filled with a single value.

    Args:
        shape: Shape of the array to create (an int or a tuple of ints).
        value: Fill value. Numbers (Python or NumPy scalars) and strings
            are supported.
        dtype: dtype of the array of ones that a numeric ``value`` is
            multiplied into; the result may be upcast by that
            multiplication (e.g. ``dtype=int`` with a float ``value``
            gives a float array). Ignored for string values.

    Returns:
        The filled array.

    Raises:
        TypeError: If ``value`` is neither a number nor a string.
    """
    if isinstance(value, (float, int, np.number)):
        # Multiplying (rather than np.full with dtype) keeps the original
        # upcasting behaviour for float values with an integer dtype.
        return np.ones(shape, dtype=dtype) * value
    if isinstance(value, str):
        return np.full(shape, value)
    # Previously this fell through to an UnboundLocalError.
    raise TypeError("value should be a number or a string, got {}".format(type(value).__name__))

def alternate_base(shape, key):
    """Return a float base array filled with 0.25 for even keys, 0.5 for odd keys.

    The array itself is built by ``get_base`` with ``dtype=float``.
    """
    fill = 0.25 if key % 2 == 0 else 0.5
    return get_base(shape, fill, dtype=float)

def my_method(ndata, search_value, search_relation):
    """Return the column indices of ``ndata`` that satisfy every row condition.

    Row ``i`` of ``ndata`` is tested against ``search_value[i]`` using
    ``search_relation[i]`` (see ``search`` for the accepted relation names).
    A column is reported only if it matches in every row. The inputs and
    the results are also printed.

    Args:
        ndata: 2-D array; one condition is applied per row.
        search_value: Sequence holding one search value per row.
        search_relation: Either a single relation name, which is applied to
            every row, or a sequence holding one relation name per row.

    Returns:
        1-D integer array with the indices of the matching columns.

    Raises:
        ValueError: If the number of relations differs from the number of
            rows, or if no column satisfies all conditions.
    """
    n_rows = len(ndata)
    if isinstance(search_relation, str):
        search_relation = (search_relation,) * n_rows
    elif len(search_relation) != n_rows:
        raise ValueError("search_relation should be a string or a collection of one relation per row")
    print("\nDATA SAMPLE:\n{}\n".format(ndata))
    print("SEARCH VALUE:    {}\nSEARCH RELATION:    {}\n".format(search_value, search_relation))
    # AND the per-row hits together as boolean masks. The per-row index
    # arrays usually differ in length, so they must not be packed into a
    # single np.array (ragged creation raises on recent NumPy).
    matched = np.ones(len(ndata.T), dtype=bool)
    for idx, relation in enumerate(search_relation):
        row_hits = np.zeros_like(matched)
        row_hits[search(condition=ndata[idx], value=search_value[idx], relation=relation)[0]] = True
        matched &= row_hits
    idx_res = np.where(matched)[0]
    val_res = ndata[:, idx_res]
    print("RESULTANT INDICES:\n{}\n".format(idx_res))
    print("RESULTANT VALUES:\n{}\n".format(val_res))
    if len(idx_res) == 0:
        raise ValueError("match not found for multiple conditions")
    return idx_res

上述方法是根据这次代码审查(code review)略作改动得到的。审查中建议的方法如下。但是这个方法只涵盖了严格的相等条件(==)。是否可以使其适应多种条件?

def martin_fabre_method(ndata, search_value):
    """Return a boolean column mask of exact matches against per-row values.

    Entry ``j`` of the result is True when column ``j`` of ``ndata`` equals
    ``search_value[i]`` in every row ``i``. The inputs and results are also
    printed. Raises ``ValueError`` when no column matches.
    """
    print("\nNDATA:\n{}\n".format(ndata))
    print("SEARCH VALUE:    {}\n".format(search_value))
    # One single-element row per search value, so that value i is broadcast
    # along row i of the data.
    targets = [[wanted] for wanted in search_value]
    column_hits = (ndata == targets).all(axis=0)
    if not column_hits.any():
        raise ValueError("match not found for multiple conditions")
    matched_values = ndata[:, column_hits]
    print("RESULTANT INDICES:\n{}\n".format(column_hits))
    print("RESULTANT VALUES:\n{}\n".format(matched_values))
    return column_hits

要运行算法,可以复制粘贴上面的代码,然后运行下面的代码:

# Example invocations. The two my_method calls (single relation, and one
# relation per row) are disabled; only the equality-only reviewed version runs.
# my_method(data, search_value=(7, 30, 22), search_relation='equality')
# my_method(data, search_value=(7, 5, 22), search_relation=('less than', 'equality', 'less than'))
martin_fabre_method(data, search_value=(7, 30, 22))

您可以将我的代码审查答案的第一行替换为如下内容:

def get_mask(data, search_value, comparison):
    """Return a boolean mask marking where ``data`` satisfies ``comparison``.

    Args:
        data: Array to test. For ``'custom'`` it is itself the condition
            and is simply converted to booleans.
        search_value: Value(s) to compare against. When ``data`` is 2-D and
            ``search_value`` is a 1-D sequence, it is treated as one value
            per row (turned into a column) so that it broadcasts along the
            rows. Ignored for ``'custom'``.
        comparison: Name of the comparison: one of the element-wise names
            in the table below, ``'nearest'``, ``'nearest forward'``,
            ``'nearest backward'`` or ``'custom'``.

    Returns:
        Boolean array. For the ``nearest*`` comparisons an entry is True
        where the element is the closest admissible one along the last
        axis (ties all marked).

    Raises:
        ValueError: If ``comparison`` is not recognised, or if some row has
            no admissible element for a ``nearest*`` comparison.
    """
    import operator  # local import keeps the snippet self-contained

    comparisons = {
        'equal': operator.eq,
        'equality': operator.eq,
        'exact match': operator.eq,
        'greater than': operator.gt,
        'greater than or equal': operator.ge,
        'less than': operator.lt,
        'less than or equal': operator.le,
        'lesser than': operator.lt,
        'lesser than or equal': operator.le,
        'not equal': operator.ne,
    }
    data = np.asarray(data)
    if comparison == 'custom':
        return data.astype(bool)

    target = np.asarray(search_value)
    if data.ndim == 2 and target.ndim == 1:
        # One search value per row: make it a column so it lines up with rows.
        target = target[:, np.newaxis]

    compare = comparisons.get(comparison)
    if compare is not None:
        return compare(data, target)

    delta = data - target
    if comparison == 'nearest':
        distance = np.abs(delta)
    elif comparison == 'nearest forward':
        # Elements below the search value are not admissible: push to +inf.
        distance = np.where(delta >= 0, delta, np.inf)
    elif comparison == 'nearest backward':
        # Elements above the search value are not admissible: push to +inf.
        distance = -np.where(delta <= 0, delta, -np.inf)
    else:
        raise ValueError("the input search relation is invalid")
    # A row made only of +inf has no admissible element at all.
    if (distance == np.inf).all(axis=-1).any():
        raise ValueError("no %s match exists for searchvalue %s" % (comparison, repr(search_value)))
    return distance == distance.min(axis=-1, keepdims=True)

def martin_fabre_method(ndata, search_value, comparison):
    """Return a boolean column mask of columns that match in every row.

    The element-wise mask comes from ``get_mask(ndata, search_value,
    comparison)``; a column is kept when that mask is True in all rows.
    The inputs and results are also printed. Raises ``ValueError`` when no
    column matches.
    """
    print("\nNDATA:\n{}\n".format(ndata))
    print("SEARCH VALUE:    {}\n".format(search_value))
    column_hits = get_mask(ndata, search_value, comparison).all(axis=0)
    if not column_hits.any():
        raise ValueError("match not found for multiple conditions")
    matched_values = ndata[:, column_hits]
    print("RESULTANT INDICES:\n{}\n".format(column_hits))
    print("RESULTANT VALUES:\n{}\n".format(matched_values))
    return column_hits

替代方案:使用 operator 模块

第一部分可以使用operator模块更清楚:

def get_mask(data, search_value, comparison):
    import operator
    comparisons = {
        'equal': operator.eq,
        'equality': operator.eq,
        'exact match': operator.eq,
        'greater than': operator.gt,
        'greater than or equal': operator.ge,
        'less than': operator.lt,
        'less than or equal': operator.le,
        'lesser than': operator.lt,
        'lesser than or equal': operator.le,
        'not equal': operator.ne,
    }
    try:
        return comparisons[comparison](data, search_value)
    ....