从 PyTorch N 维张量中过滤掉 NaN 值

Filter out NaN values from a PyTorch N-Dimensional tensor

这个问题很相似。不同的是,我想将相同的概念应用于2维或更高维的张量。

我有一个看起来像这样的张量:

import torch

tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
 [float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
 [2, 2, 2, 2, 2]]
)
>>> tensor.shape
>>> [3, 5]

我想找到最 pythonic / PyTorch 的方法来过滤(删除)张量的行 nan。通过沿第一个(0th 轴)过滤此 tensor,我想获得一个 filtered_tensor,如下所示:

>>> print(filtered_tensor)
>>> torch.Tensor(
[[1, 1, 1, 1, 1],
 [2, 2, 2, 2, 2]]
)
>>> filtered_tensor.shape
>>> [2, 5]

使用 PyTorch 的 isnan()any() 一起使用获得的布尔掩码对 tensor 的行进行切片,如下所示:

filtered_tensor = tensor[~torch.any(tensor.isnan(),dim=1)]

请注意,这将删除任何包含 nan 值的行。如果您只想删除所有值为 nan 的行,请将 torch.any 替换为 torch.all

对于 N-dimensional 张量,您可以将除第一个暗淡之外的所有暗淡变平并应用与上述相同的过程:

#Flatten:
shape = tensor.shape
tensor_reshaped = tensor.reshape(shape[0],-1)
#Drop all rows containing any nan:
tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)]
#Reshape back:
tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:])