在排序的多维 numpy 数组中沿轴查找最后一个非 NaN 值

Find last non-NaN value along axis in sorted multi-dimensional numpy array

我正在查看一些 3D 海洋温度数据(时间、深度、经度、纬度),并想提取最低深度的值以创建海底温度的 2D 地图。

海底是一个掩码,它沿深度轴创建某种排序数组,所有 NaN 值都集中在轴 1 的末端。

复制此代码的一些示例代码:

import numpy as np

A=np.random.rand(6,50,300,360)*100
A.ravel()[np.random.choice(A.size, 10000000, replace=False)] = np.nan
A.sort(axis=1)

然后,在 之后,我得到一个数组,其中包含沿轴 1 的最终非 NaN 元素的索引:

lv=(~np.isnan(A)).sum(axis=1)-1

现在棘手的部分是使用 lvA 的轴 1 中提取值(我要提取的元素数组) . 到目前为止,我最好的方法(确实有效)是创建一个适当大小的空数组并按元素填充它:

B=np.zeros(lv.shape,dtype=np.float32)
for i in range(t):
    for j in range(y):
        for k in range(x):
            B[i,j,k]=A[i,lv[i,j,k],j,k]

但是这很慢;对于我希望使用它的数据量(价值很多 TB)来说,这是不合理的。

关于如何简化最后阶段的任何想法(如 ,但对于 numpy)? 我在想一些事情(虽然我意识到这甚至没有意义):

B=A[:,lv[:],:,:]

我也试过 np.take、np.take_along_axis 和 np.choose 的变体,但都没有成功。

提前感谢您的任何建议!

Numpy 的 take_along_axis 应该可以解决这个问题。最后一步可以表示为:

B = np.take_along_axis(A, lv[:,None,:,:], axis=1).squeeze()