如何指定在 PyTorch 中获取元素的轴
How to specify the axis over which to get the elements in PyTorch
我一直在寻找关于 SO 的答案,但老实说,我什至不知道如何表达问题。
给定张量,如何获取指定轴上的索引
这是一个简单的例子
indices = torch.tensor([2, 4, 5])
# Depending on the context I need either
y = x[indices]
# Or
y = x[:, indices]
# Or any other axis
y = x[:, :, :, indices]
这是一个我需要这种行为的用例
def remove_weaklings(x: Tensor, percentage: float, axis: int) -> Tensor:
all_axes = set(range(x.ndim)) - set([axis])
y = x
# Sumup over all other axes
for a in all_axes:
y = y.sum(axis=a, keepdim=True)
y = y.squeeze()
# Get the sorted list
_, idx = torch.sort(y)
# Get only a fraction of the list
idx = idx[:int(percentage * len(idx))]
# Get the indices over some axis
# !!! This part not sure how to solve !!!
return x.get_over_axis(axis=axis, indices=idx)
我想你在找 topk
:
def remove_weaklings(x, percentage, axis):
k = int(percentage * x.shape[axis])
return torch.topk(x, k, dim=axis)
如果您想要更通用的解决方案,可以使用 numpy
的 slicing:
def get_over_axis(x, indices, axis):
i = []
for d_ in range(x.dim()):
if d_ == axis:
i.append(indices)
else:
i.append(np.s_[:])
return x[tuple(i)]
我一直在寻找关于 SO 的答案,但老实说,我什至不知道如何表达问题。
给定张量,如何获取指定轴上的索引
这是一个简单的例子
indices = torch.tensor([2, 4, 5])
# Depending on the context I need either
y = x[indices]
# Or
y = x[:, indices]
# Or any other axis
y = x[:, :, :, indices]
这是一个我需要这种行为的用例
def remove_weaklings(x: Tensor, percentage: float, axis: int) -> Tensor:
all_axes = set(range(x.ndim)) - set([axis])
y = x
# Sumup over all other axes
for a in all_axes:
y = y.sum(axis=a, keepdim=True)
y = y.squeeze()
# Get the sorted list
_, idx = torch.sort(y)
# Get only a fraction of the list
idx = idx[:int(percentage * len(idx))]
# Get the indices over some axis
# !!! This part not sure how to solve !!!
return x.get_over_axis(axis=axis, indices=idx)
我想你在找 topk
:
def remove_weaklings(x, percentage, axis):
k = int(percentage * x.shape[axis])
return torch.topk(x, k, dim=axis)
如果您想要更通用的解决方案,可以使用 numpy
的 slicing:
def get_over_axis(x, indices, axis):
i = []
for d_ in range(x.dim()):
if d_ == axis:
i.append(indices)
else:
i.append(np.s_[:])
return x[tuple(i)]