如何使用 numpy apply_along_axis() 访问相关轴的索引值

How to access index value of the axis in question with numpy apply_along_axis()

我有一个可以完全矢量化的问题,但我没有足够的 space,所以我正在尝试使用 numpy 的 apply_along_axis() 的一半和一半的解决方案。

(注意:这是一个玩具示例,说明了问题的核心。换句话说,我不是在寻找一个 numpy 或 scipy 函数来完成这里函数的作用 - 它是不是真正的功能,只是一个简单的说明。)

我想做的是找出一种方法来访问每次迭代时传递的轴的索引。

假设我们采用了 4 x 4 矩阵:

    M = np.array(([0,0,1,1], [1,1,0,1], [1,0,1,0], [0,0,1,1]))
    M 
   
   array([[0, 0, 1, 1],
          [1, 1, 0, 1],
          [1, 0, 1, 0],
          [0, 0, 1, 1]])

并且想要计算每一列与其他每一列的成对按位逻辑和,但是为了节省(大量)时间,我们只计算列 i,j,其中 j > i 的索引(这样我们最终得到一个三角矩阵)。

在 pandas 中,我可以使用 apply() 很容易地做到这一点,但它对我的目的来说太慢了。

我知道 scikit-learn 中有成对函数,但请假设这些函数不适合我的目的(我的函数比这个玩具函数更复杂)

如果我要使用 numpy 的 apply_long_axis(),我只能计算出如何比较所有 i,j 和 j,i,而不是之前描述的较小的问题。

这是我的解决方案:

def intersections_np(col, M):
    col = col[:,np.newaxis]
    intersection = (M & col).sum(0)
    return(intersection)

result_np = np.apply_along_axis(intersections_np, arr = M, axis = 0,  M = M)
result_np

array([[2, 1, 1, 1],
       [1, 1, 0, 1],
       [1, 0, 3, 2],
       [1, 1, 2, 3]], dtype=int32)

但我真正想做的是:

def intersections_np(col, M):
    col = col[:,np.newaxis]
    start_index = <index_of_current_column> + 1
    other_cols = M[:,start_index:]
    intersection = (other_cols & col).sum(0)
    <possible padding of the array with nans here>
    return(intersection)

result_np = np.apply_along_axis(intersections_np, arr = M, axis = 0,  M = M)

和return:

result_np

array([[nan, nan, nan, nan],
       [1, nan, nan, nan],
       [1, 0, nan, nan],
       [1, 1, 2, nan]], dtype=int32)

有谁知道这样的事情是否可以做到?

谢谢

这有一个相当不错的 pythonic 解决方案:

np.reshape(np.array([(M.T[i] & M.T[j]).sum(0) if j>i else 0 \
for i in range(len(M.T)) for j in range(len(M.T))]),(M.T).shape).T

M.T的用途是访问列。结果向量被重塑为与转置数组相同的形状。然后将数组转置回原始数组形状并产生所需的输出。

我们来计时一下。

你的基础apply

In [142]: timeit np.apply_along_axis(intersections_np, arr = M, axis = 0,  M = M)                    
158 µs ± 3.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

和等效迭代(技术上可能需要转置,结果是对称的,所以没关系):

In [143]: timeit np.array([intersections_np(M[:,i],M) for i in range(M.shape[1])])                   
65.4 µs ± 1.93 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

和@jfahne 建议:

In [144]: %%timeit  
     ...: np.reshape(np.array([(M.T[i] & M.T[j]).sum(0) if j>i else 0 \ 
     ...: for i in range(len(M.T)) for j in range(len(M.T))]),(M.T).shape).T 
     ...:  
     ...:                                                                                            
95.2 µs ± 2.99 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

请注意 apply 比普通迭代慢。这与我过去的测试一致。 apply 仅在数组为 3d 或更多时才有帮助,并且迭代是 'ugly' 双嵌套。它更漂亮,但仍然没有更快。这是一个方便的工具,而不是一个速度工具。

一个完整的'vectorized'解决方案(带有numpy广播等):

In [148]: (M[:,:,None] & M[:,None,:]).sum(0)                                                         
Out[148]: 
array([[2, 1, 1, 1],
       [1, 1, 0, 1],
       [1, 0, 3, 2],
       [1, 1, 2, 3]])
In [149]: timeit (M[:,:,None] & M[:,None,:]).sum(0)                                                  
14.9 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

它确实生成了一个中间 (4,4,4) 数组,并且不会避免重复,但是因为在 Python 级别没有迭代,所以它非常快。试图将计算限制在下(或上)三角形通常是不值得的。

但是如果你真的想要下三角和速度,可以考虑使用numba。对于迭代问题,它可以非常快(但会牺牲一些灵活性)。


这是你的交叉路口版本,仅限于下三角

In [159]: def foo(M): 
     ...:     m = M.shape[0] 
     ...:     res = np.full((m,m), np.nan) 
     ...:     for i in range(m-1): 
     ...:         temp = (M[:,i,None] & M[:,(i+1):]).sum(0) 
     ...:         res[-temp.shape[0]:,i] = temp 
     ...:     return res 
     ...:      
     ...:                                                                                            
In [160]: foo(M)                                                                                     
Out[160]: 
array([[nan, nan, nan, nan],
       [ 1., nan, nan, nan],
       [ 1.,  0., nan, nan],
       [ 1.,  1.,  2., nan]])
In [161]: timeit foo(M)                                                                              
59.3 µs ± 2.42 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

基本上与我的 [143] 相同的时间 - 它在 & 步骤中的计算较少,但索引较多,因此速度变化很小。