在没有for循环的情况下在批处理数据的每一帧中找到子数组第一次出现的索引的最佳方法

Optimal way to find index of first occurrence of subarray in each frame of batch data without for loop

我必须在每个 frame.The 大小为 (batch_size,400) 的数据中找到子数组第一次出现的索引。我需要在每个大小为 400 的帧中找到三个连续出现的索引。 数据-> [0 0 0 1 1 1 0 1 1 1 1 1][0 0 0 0 1 1 1 0 0 1 1 1] [0 1 1 1 0 0 0 1 1 1 1 1]

输出应该是[3 4 1]

本机解决方案是使用 for 循环,但由于数据很大,因此非常耗时。

numpytensorflow 中快速高效的任何实现

对此没有简单的 numpy 解决方案。但是,如果您确实需要快速执行以下操作,则可以使用 numba:

函数 find_first 基本上完成了您将使用 for 循环执行的操作。但是由于您使用的是 numba,因此会编译该方法,因此速度要快得多。 然后,您只需使用 np.apply_along_axis:

将该方法应用于每个批次
import numpy as np
from numba import jit


@jit(nopython=True)
def find_first(seq, arr):
    """return the index of the first occurence of item in arr"""
    for i in range(len(arr)-2):
        if np.all(seq == arr[i:i+3]):
            return i
    return -1

# construct test array
test = np.round(np.random.random((64,400)))

# this will give you the array of indices
np.apply_along_axis(lambda m: find_first(np.array([1,1,1]), m), axis=1, arr = test)

我修改了this answer

的方法