如何根据列拆分 numpy 数组?

How to split a numpy array based on a column?

我有一个如下形式的数组:

[[ 1. ,    2.,     3.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.3,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.2,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.1,    2.1,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.5,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.7,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.3,    2.2,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.6,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.8,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.4,    2.3,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.3,    3.7,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.9,    3.2,    2.,     3.2,    3.2,    4.2  ],
 [ 1.5,    2.1,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.89,   2.3,    3.5,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.7,    3.2,    2.,     3.2,    3.231,  4.2  ],
 [ 1.9,    2.2,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.22,   3.6,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.8,    3.2,    2.,     3.66,   3.2,    4.2  ],
 [ 1.89,   2.3,    1.,     1.,     3.,     3.,     4.   ],
 [ 1.3,    2.99,   3.7,    3.,     3.3,    3.3,    4.3  ],
 [ 1.2,    2.9,    3.2,    2.,     3.34,   3.2,    4.2  ]]

我想根据第四列将这个数组拆分成多个子数组。 IE。我想要一个第四列等于 1 的子数组,另一个第四列等于 2 的子数组,等等。我事先不知道第四列中所有可能的值是什么。

比如第四列为1对应的子数组为:

[[ 1.     2.     3.     1.     3.     3.     4.   ],
 [ 1.1    2.1    1.     1.     3.     3.     4.   ],
 [ 1.3    2.2    1.     1.     3.     3.     4.   ],
 [ 1.4    2.3    1.     1.     3.     3.     4.   ],
 [ 1.5    2.1    1.     1.     3.     3.     4.   ],
 [ 1.9    2.2    1.     1.     3.     3.     4.   ],
 [ 1.89   2.3    1.     1.     3.     3.     4.   ]]

查看 docs 将数组拆分为多个子数组。

numpy.hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

假设你有一个 4x4 数组 A:

array([[  0.,   1.,   2.,   3.],
   [  4.,   5.,   6.,   7.],
   [  8.,   9.,  10.,  11.],
   [ 12.,  13.,  14.,  15.]])

split = numpy.hsplit(A,4) = 

[array([[  0.],
   [  4.],
   [  8.],
   [ 12.]]), array([[  1.],
   [  5.],
   [  9.],
   [ 13.]]), array([[  2.],
   [  6.],
   [ 10.],
   [ 14.]]), array([[  3.],
   [  7.],
   [ 11.],
   [ 15.]])]

制作数组列表:

y = [x[x[:,3]==k] for k in np.unique(x[:,3])]

您可以使用 numpy.argsortnumpy.array_splitnumpy.diffnumpy.where:

O(NlogN) 时间内完成此操作
>>> indices = np.argsort(arr[:, 3])
>>> arr_temp = arr[indices]
>>> np.array_split(arr_temp, np.where(np.diff(arr_temp[:,3])!=0)[0]+1)
[array([[ 1.  ,  2.  ,  3.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.89,  2.3 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.1 ,  2.1 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.9 ,  2.2 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.3 ,  2.2 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.5 ,  2.1 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ],
       [ 1.4 ,  2.3 ,  1.  ,  1.  ,  3.  ,  3.  ,  4.  ]]), array([[ 1.2  ,  2.8  ,  3.2  ,  2.   ,  3.66 ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.7  ,  3.2  ,  2.   ,  3.2  ,  3.231,  4.2  ],
       [ 1.2  ,  2.9  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.9  ,  3.2  ,  2.   ,  3.34 ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.8  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.7  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ],
       [ 1.2  ,  2.2  ,  3.2  ,  2.   ,  3.2  ,  3.2  ,  4.2  ]]), array([[ 1.3 ,  2.3 ,  3.6 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.89,  2.3 ,  3.5 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.3 ,  3.5 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.22,  3.6 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.3 ,  3.3 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.99,  3.7 ,  3.  ,  3.3 ,  3.3 ,  4.3 ],
       [ 1.3 ,  2.3 ,  3.7 ,  3.  ,  3.3 ,  3.3 ,  4.3 ]])]

我以某种方式转变了@ashwini-chaudhary 的想法,returns 感兴趣的指数用于以后的迭代。所以我想我会分享它:

def split_idx_by_dim(dim_array):
    """Returns a sequence of arrays of indices of elements sharing the same value in dim_array"""
    idx = np.argsort(dim_array)
    sorted_cl_ids = dim_array[idx]
    split_idx = np.array_split(idx, np.where(np.diff(sorted_cl_ids) != 0)[0] + 1)
    return split_idx