以 numpy 中的矢量化方式:根据块内的内容有条件地剪切数据块(不同大小)

In a vectorized way in numpy: Cutting out chunks of data (of different sizes) conditionally dependent on whats inside the chunk

我有一个整数向量(一维 numpy 数组),如下所示:

8, 1, 1, 2, 8, 99, 1, 2, 1, 2, 8, 2, 2, 2, 8, 99, 99, 8, 1, 1

(以矢量化的方式,)我想过滤掉所有包含至少一个 99 值的 8 之间的数据。

所以在这个例子中,我要截取的数据以粗体列出:

8, 1, 1, 2, 8, 99, 1, 2, 1, 2, 8, 2, 2, 2, 8, 99, 99, 8, 1, 1

(即,它是介于 8 之间且至少包含一个 99 的数据)

所以如果我制作一个布尔掩码来裁剪这些数据,它看起来像:

Data: 8, 1, 1, 2, 8, 99, 1, 2, 1, 2, 8, 2, 2, 2, 8, 99, 99, 8, 1, 1
Mask: T, T, T, T, F,  F, F, F, F, F, T, T, T, T, F,  F,  F, T, T, T

裁剪后的数据如下:

Data(Mask) = 8, 1, 1, 2, 8, 2, 2, 2, 8, 1, 1

如果保证 8 之间的间距相等,我可以想出可以执行此操作的矢量化代码。这是代码:

inputRaw = np.array([8, 2, 3, 2, 99, 2, 8, 2, 3, 2, 2, 2, 8, 2, 3, 3, 3, 3])
inputPartitioned = np.reshape(inputRaw, (3, 6))
# reshaping it into an array of the form: np.array([[8, 2, 3, 2, 99, 2], [8, 2, 3, 2, 2, 2], [8, 2, 3, 3, 3, 3]])
selectedSections = np.logical_not(np.any(inputPartitioned>8, axis=1))
outputPartitioned = inputPartitioned[selectedSections]
outputFlattened = outputPartitioned.flatten()

我需要的另一件事是一个遮罩或索引,它告诉我(在原始索引中)被裁剪的数据。 (我需要这个,因为我有第二个数组,我想跟踪它与第一个数组共享索引)。我可以像这样编写这个掩码(假设 8 之间的间距相等):

inputIndex = np.arange(inputRaw.size)
inputIndexPartitioned =  np.reshape(inputIndex, (3, 6))
outputPartitionedIndex = inputIndexPartitioned[selectedSections]
outputFlattenedIndex = outputPartitionedIndex.flatten()

但我不确定在 8 之间的间距不相等的情况下如何以矢量化方式执行此操作。

有什么想法吗?这些阵列非常长,因此适用于大型阵列的快速解决方案很有帮助。另外,我相当有信心这些“99”总是紧跟在 8 之后,所以这可能对制定算法有帮助。

这应该可以解决问题(我已经避免了 "fairly confident" 部分,所以它会 运行 即使 99 不是紧跟在 8 之后)。

import numpy as np

in_arr = np.array([8, 1, 1, 2, 8, 99, 1, 2, 1, 2, 8, 2, 2, 2, 8, 99, 99, 8, 1, 1])

mask_8 = in_arr == 8
mask_8_cumsum = np.cumsum(mask_8)
print(mask_8_cumsum)
>>> [1 1 1 1 2 2 2 2 2 2 3 3 3 3 4 4 4 5 5 5]

unique_inds = np.unique(mask_8_cumsum[in_arr == 99])
print(unique_inds)
>>> [2 4]

final_mask = ~np.isin(mask_8_cumsum, unique_inds)
final_data = in_arr[final_mask]
print(final_mask)
>>> [ True  True  True  True False False False False False False  True  True
  True  True False False False  True  True  True]

print(final_data)
>>> [8 1 1 2 8 2 2 2 8 1 1]

方法 #1

这是一个通用案例的矢量化案例,当 99's 出现在两个 8's 之间并且这两个 8's 之间的所有元素都将是 removed/masked-out -

def vectorized1_app(a):
    m1 = a==8
    m2 = a==99
    d = m1.cumsum()
    occ_mask = np.bincount(d,m2)<1
    if m1.argmax() > m2.argmax():
        occ_mask[0] = ~occ_mask[0]

    if m1[::-1].argmax() > m2[::-1].argmax():
        occ_mask[-1] = ~occ_mask[-1]
    mask = occ_mask[d]
    return mask

方法 #2

对于 8's99's 的特定情况,我们还可以使用 JIT 编译的 numba 代码 -

from numba import njit

@njit
def numba1(a, mask_out):
    N = len(a)
    fill = False
    last8_index = 0
    for i in range(N-1):
        if a[i]==8:
            if a[i+1]==99:
                fill = True
            else:
                fill = False
            last8_index = i

        if fill:        
            mask_out[i] = False

    return mask_out, last8_index

def numba1_app(a):
    N = len(a)
    mask = np.ones(N, dtype=np.bool)
    mask, last8_index = numba1(a, mask)
    if a[-1]!=8:
        mask[last8_index:] = True
    return mask

方法 #2-B

一些边际性能。通过将最后一个元素检查的最后一步推到 numba-part 来提升,就像这样 -

@njit
def numba2(a, mask_out):
    N = len(a)
    fill = False
    last8_index = 0
    for i in range(N-1):
        if a[i]==8:
            if a[i+1]==99:
                fill = True
            else:
                fill = False
            last8_index = i

        if fill:        
            mask_out[i] = False

    if a[N-1]!=8:
        for i in range(last8_index,N-1):
            mask_out[i] = True

    return mask_out

def numba2_app(a):
    return numba2(a, np.ones(len(a), dtype=np.bool))

请注意,已发布方法的输出是掩码,因此用这些掩码输入数组会给我们提供与 Data(Mask).

相当的结果

特殊情况:屏蔽直到第一个 8,在最后一个 8

之后保留

我们可以通过两种方式修改 app#1 -

应用#1-Mod#1-

m1 = a==8
m2 = a==99
d = m1.cumsum()
occ_mask = np.bincount(d,m2)<1
occ_mask[0] = False
mask = occ_mask[d]

如果您必须修改大小写以便您也希望在最后一个 8 之后进行屏蔽,只需执行以下操作:occ_mask[-1] = False.

应用#1-Mod#2-

m1 = a==8
m2 = a==99
d = m1.cumsum().clip(min=1)
occ_mask = np.bincount(d,m2)<1
mask = occ_mask[d]

如果您必须修改大小写以便您也希望在最后一个 8 之后进行屏蔽,请执行:m1c = m1.cumsum(); d = m1c.clip(min=1, max=m1c.max()-1).

另一种 numpy 方法:

def pp(a):                                            
    m8 = a==8
    m99 = a==99
    m = m8|m99
    i = m.nonzero()[0]
    c8 = m8[i]
    i = i[c8]
    n8 = np.count_nonzero(c8)
    if n8 == 0:
        return np.ones(a.size,bool)
    if c8[-1]:
        d8 = np.empty(n8,bool)
        d8[-1] = False
        d8[:-1] = ~c8[1:][c8[:-1]]
    else:
        d8 = ~c8[1:][c8[:-1]]
    d8[1:]^=d8[:-1]
    m8[i] = d8
    m8[0]^=True
    return np.bitwise_xor.accumulate(m8)

例如:

a = np.array([8, 1, 1, 2, 8, 99, 1, 2, 1, 2, 8, 2, 2, 2, 8, 99, 99, 8, 1, 1])
a[pp(a)]
# array([8, 1, 1, 2, 8, 2, 2, 2, 8, 1, 1])