对 3d 数组的配对数组进行排序(替换循环)

Sort paired array of 3d array (replace for loop)

我有以下 3d 数组:

import numpy as np

z = np.array([[[10,  2],
               [ 5,  3],
               [ 4,  4]],
              [[ 7,  6],
               [ 4,  2],
               [ 5,  8]]])

我想根据第三个暗淡度和第一个值对它们进行排序。

目前我正在使用以下代码:

from operator import itemgetter

np.array([sorted(x,key=itemgetter(0)) for x in z])
array([[[ 4,  4],
        [ 5,  3],
        [10,  2]],

       [[ 4,  2],
        [ 5,  8],
        [ 7,  6]]])

我想通过删除 for 循环使代码更 efficient/faster 吗?

您可以使用 map() 来获得相同的结果,而无需 for 循环。并且排序函数是用户定义的,或者是 lambda,或者是 sorted:

的一部分
  1. 首先创建一个排序函数:

    >>> def mysort(it):
    ...   return sorted(it, key=itemgetter(0))
    ...
    >>> list(map(mysort, z))
    [[[4, 4], [5, 3], [10, 2]], [[4, 2], [5, 8], [7, 6]]]
    
  2. 同上,但用 lambda 代替:

    >>> list(map(lambda it: sorted(it, key=itemgetter(0)), z))
    [[[4, 4], [5, 3], [10, 2]], [[4, 2], [5, 8], [7, 6]]]
    
  3. 同一个partial:

    >>> from functools import partial
    >>> psort = partial(sorted, key=itemgetter(0))
    >>> list(map(psort, z))
    [[[4, 4], [5, 3], [10, 2]], [[4, 2], [5, 8], [7, 6]]]
    

    或就地定义的部分:

    >>> list(map(partial(sorted, key=itemgetter(0)), z))
    [[[4, 4], [5, 3], [10, 2]], [[4, 2], [5, 8], [7, 6]]]
    
  4. 你的问题有一个列表的列表,而不是一个 3d numpy 数组。对于面向 numpy 的解决方案,see this answer.

仅供参考,(2) 和 (3b) 是 roughly equivalent, but have their differences
在选项1-3中,我更喜欢(2)中的lambda。

对于 numpy one liner 你可以使用 numpy.argsort:

import numpy as np

a = np.array([[[10,  2],
               [ 5,  3],
               [ 4,  4]],
              [[ 7,  6],
               [ 4,  2],
               [ 5,  8]]])

a[np.arange(0,2)[:,None], a[:,:,0].argsort()]
array([[[ 4,  4],
        [ 5,  3],
        [10,  2]],
       [[ 4,  2],
        [ 5,  8],
        [ 7,  6]]])

对于这么小的数组来说,这需要大约相同的时间,但扩大规模会带来相当大的改进,例如:

from operator import itemgetter

a = np.random.randint(0,10, (2,100_000,2))

%timeit a[np.arange(0,2)[:,None], a[:,:,0].argsort()]
26.9 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit [sorted(x,key=itemgetter(0)) for x in a]
327 ms ± 6.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

为什么不简单地:np.sort(z,axis=1)

import numpy as np

z = np.array([[[10,  2],
               [ 5,  3],
               [ 4,  4]],
              [[ 7,  6],
               [ 4,  2],
               [ 5,  8]]])

print(np.sort(z,axis=1))

[[[ 4  2]
  [ 5  3]
  [10  4]]

 [[ 4  2]
  [ 5  6]
  [ 7  8]]]