如何指定在 PyTorch 中获取元素的轴

How to specify the axis over which to get the elements in PyTorch

我一直在寻找关于 SO 的答案,但老实说,我什至不知道如何表达问题。

给定张量,如何获取指定轴上的索引

这是一个简单的例子

indices = torch.tensor([2, 4, 5])

# Depending on the context I need either
y = x[indices]

# Or
y = x[:, indices]

# Or any other axis
y = x[:, :, :, indices]

这是一个我需要这种行为的用例


def remove_weaklings(x: Tensor, percentage: float, axis: int) -> Tensor:
    all_axes = set(range(x.ndim)) - set([axis])
    y = x
    # Sumup over all other axes
    for a in all_axes:
        y = y.sum(axis=a, keepdim=True)
    y = y.squeeze()
    # Get the sorted list
    _, idx = torch.sort(y)
    # Get only a fraction of the list
    idx = idx[:int(percentage * len(idx))]
    # Get the indices over some axis
    # !!! This part not sure how to solve !!!
    return x.get_over_axis(axis=axis, indices=idx)  

我想你在找 topk:

def remove_weaklings(x, percentage, axis):
  k = int(percentage * x.shape[axis])
  return torch.topk(x, k, dim=axis)

如果您想要更通用的解决方案,可以使用 numpyslicing:

def get_over_axis(x, indices, axis):
  i = []
  for d_ in range(x.dim()):
    if d_ == axis:
      i.append(indices)
    else:
      i.append(np.s_[:])
  return x[tuple(i)]