Vectorized extraction of a sub-array from a multidimensional array using a list of indices
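I have a multidimensional array of shape (500000, 3, 2, 3); let's call it data. It holds 500000 sets of 3 points, with each point split into its x and y coordinates (hence the 2). The final 3 in the shape indexes 3 different rotations of the 3 points. I also have a 1-D array of 500000 integers between 0 and 2 that tells me which rotation to keep; let's call it rot_index. I want to build an array of shape (500000, 3, 2) that keeps only the data points for the correct rotation. Any ideas on how to extract the data with the right index from the original array? I tried something like this, but it didn't work: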
data[:,:,:,rot_index]
Edit:
Here is some sample data (10 sets of points instead of 500000):
data =
[[[[0.70846822 0.98552876 0.66736535]
[0. 0. 0. ]]
[[0.66736535 0.70846822 0.98552876]
[1.54545219 2.39798549 2.33974762]]
[[0.98552876 0.66736535 0.70846822]
[3.88519982 3.94343768 4.73773311]]]
[[[0.8132551 1.18845796 1.53004225]
[0. 0. 0. ]]
[[1.18845796 1.53004225 0.8132551 ]
[1.43211754 2.58720625 2.26386152]]
[[1.53004225 0.8132551 1.18845796]
[4.01932379 4.85106777 3.69597906]]]
[[[0.66123513 0.93651048 0.83170562]
[0. 0. 0. ]]
[[0.93651048 0.83170562 0.66123513]
[2.09747072 2.38383457 1.80188002]]
[[0.83170562 0.66123513 0.93651048]
[4.48130529 4.18571459 3.89935074]]]
[[[1.31047414 0.67740955 1.42020073]
[0. 0. 0. ]]
[[0.67740955 1.42020073 1.31047414]
[1.66061575 1.97600777 2.64656179]]
[[1.42020073 1.31047414 0.67740955]
[3.63662352 4.62256956 4.30717753]]]
[[[1.4085555 1.64177102 0.27708893]
[0. 0. 0. ]]
[[0.27708893 1.4085555 1.64177102]
[0.62154257 3.04315813 2.61848461]]
[[1.64177102 0.27708893 1.4085555 ]
[3.24002718 3.6647007 5.66164274]]]
[[[0.48080385 0.85910831 0.52342904]
[0. 0. 0. ]]
[[0.52342904 0.48080385 0.85910831]
[1.08970318 2.57102289 2.62245924]]
[[0.85910831 0.52342904 0.48080385]
[3.71216242 3.66072607 5.19348213]]]
[[[1.13610207 1.51237019 0.47256909]
[0. 0. 0. ]]
[[1.51237019 0.47256909 1.13610207]
[2.92304081 2.59328103 0.76686347]]
[[0.47256909 1.13610207 1.51237019]
[5.51632184 3.3601445 3.68990428]]]
[[[1.08397801 1.16506242 0.84703646]
[0. 0. 0. ]]
[[1.16506242 0.84703646 1.08397801]
[2.37250664 2.04419242 1.86648625]]
[[0.84703646 1.08397801 1.16506242]
[4.41669906 3.91067866 4.23899289]]]
[[[0.98734317 1.11177984 0.90283297]
[0. 0. 0. ]]
[[1.11177984 0.90283297 0.98734317]
[2.25981006 2.13666143 1.88671382]]
[[0.90283297 0.98734317 1.11177984]
[4.39647149 4.02337525 4.14652387]]]
[[[1.94118244 1.14738719 1.98251535]
[0. 0. 0. ]]
[[1.14738719 1.98251535 1.94118244]
[1.83291888 1.90183408 2.54843234]]
[[1.98251535 1.94118244 1.14738719]
[3.73475296 4.45026642 4.38135123]]]]
Here is the list of indices I want to keep:
rot_index = np.array([1, 2, 1, 1, 1, 1, 1, 2, 1, 1])
As an example, consider
data[0,:,:,0] = [[0.70846822 0.]
[0.66736535 1.54545219]
[0.98552876 3.88519982]]
data[0,:,:,1] = [[0.98552876 0.]
[0.70846822 2.39798549]
[0.66736535 3.94343768]]
data[0,:,:,2] = [[0.66736535 0.]
[0.98552876 2.33974762]
[0.70846822 4.73773311]]
These are 3 different "rotations" of the same sample. The first element of rot_index is 1, so I only want to keep
data[0,:,:,1] = [[0.98552876 0.]
[0.70846822 2.39798549]
[0.66736535 3.94343768]]
Using NumPy advanced indexing, and specifically the rules for combining advanced and basic indexing, this should work (where data_array is a NumPy ndarray holding data):
result = data_array[range(500000),...,rot_index]
For your sample data, this yields:
[[[0.98552876 0. ]
[0.70846822 2.39798549]
[0.66736535 3.94343768]]
[[1.53004225 0. ]
[0.8132551 2.26386152]
[1.18845796 3.69597906]]
[[0.93651048 0. ]
[0.83170562 2.38383457]
[0.66123513 4.18571459]]
[[0.67740955 0. ]
[1.42020073 1.97600777]
[1.31047414 4.62256956]]
[[1.64177102 0. ]
[1.4085555 3.04315813]
[0.27708893 3.6647007 ]]
[[0.85910831 0. ]
[0.48080385 2.57102289]
[0.52342904 3.66072607]]
[[1.51237019 0. ]
[0.47256909 2.59328103]
[1.13610207 3.3601445 ]]
[[0.84703646 0. ]
[1.08397801 1.86648625]
[1.16506242 4.23899289]]
[[1.11177984 0. ]
[0.90283297 2.13666143]
[0.98734317 4.02337525]]
[[1.14738719 0. ]
[1.98251535 1.90183408]
[1.94118244 4.45026642]]]
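To check the idea on a small scale, here is a minimal sketch using random stand-in data (n = 10 rather than 500000; the names rng, n, attempt, alt and expected are mine, not from the question). It shows why the original attempt gives the wrong shape, applies the indexing expression above with np.arange in place of range, and compares the result against a plain Python loop. The np.take_along_axis line is just one alternative spelling that should give the same values.

```python
import numpy as np

# Random stand-in data; n = 10 is an assumption to keep the example small.
rng = np.random.default_rng(0)
n = 10
data_array = rng.random((n, 3, 2, 3))
rot_index = rng.integers(0, 3, size=n)

# The attempt from the question indexes only the last axis, so every
# sample gets all n rotation picks: the shape becomes (n, 3, 2, n).
attempt = data_array[:, :, :, rot_index]

# Pairing a row index with rot_index selects one rotation per sample.
# The two index arrays are broadcast together and their shared dimension
# comes first in the result, giving shape (n, 3, 2).
result = data_array[np.arange(n), ..., rot_index]

# Alternative spelling: gather along the last axis, then drop that axis.
alt = np.take_along_axis(
    data_array, rot_index[:, None, None, None], axis=-1
)[..., 0]

# Reference result built with an explicit loop, for comparison only.
expected = np.stack([data_array[i, :, :, r] for i, r in enumerate(rot_index)])

print(attempt.shape, result.shape)
print(np.array_equal(result, expected), np.array_equal(alt, expected))
```

The loop version is easier to read but runs in Python for every sample, so for 500000 samples the indexed form is likely to be the better choice.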