基于旧张量和二维索引的新张量
New tensor based on old tensor and 2d indices
我之前问过:
我现在有同样的问题,但需要使用二维索引张量。
我有一个大小为 [batch_size, k] 的张量列,其值介于 0 和 k-1 之间:
idx = tensor([[0,1,2,0],
[0,3,2,2],
...])
和以下矩阵:
x = tensor([[[0, 9],
[1, 8],
[2, 3],
[4, 9]],
[[0, 0],
[1, 2],
[3, 4],
[5, 6]]])
我想创建一个新的张量,其中包含索引中指定的行,按顺序排列。所以我想要:
tensor([[[0, 9],
[1, 8],
[2, 3],
[0, 9]],
[[0, 0],
[5, 6],
[3, 4],
[3, 4]]])
目前我是这样做的:
for i, batch in enumerate(t):
t[i] = batch[col[i]]
我怎样才能更有效地做到这一点?
您可以 gather
索引,经过轻微的操作使其成为兼容的形状:
>>> new_idx = idx.unsqueeze(-1).expand_as(x)
>>> x.gather(1, new_idx)
tensor([[[0, 9],
[1, 8],
[2, 3],
[0, 9]],
[[0, 0],
[5, 6],
[3, 4],
[3, 4]]])
您应该使用 torch gather 来实现此目的。它实际上也适用于您链接的其他问题,但这留作 reader :p
的练习
让我们称 idx
你的第一个张量和 source
第二个张量。它们各自的维度是 (B,N)
和 (B, K, p)
(在您的示例中是 p=2
),并且 idx
的所有值都在 0
和 K-1
之间。
所以要使用torch gather,我们首先需要将你的操作表达为一个嵌套的for循环。在你的情况下,你真正想要实现的是:
for b in range(B):
for i in range(N):
for j in range(p):
# This kind of nested for loops is what torch.gether actually does
target[b,i,j] = source[b, idx[b,i,j], j]
但这不起作用,因为 idx
是 2D 张量,而不是 3D 张量。好吧,没什么大不了的,让我们把它变成一个 3D 张量。我们希望它的形状为 (B, N, p)
并且在最后一个维度上实际上是恒定的。然后我们可以用调用 gather
:
来替换 for 循环
reshaped_idx = idx.unsqueeze(-1).repeat(1,1,2)
target = source.gather(1, reshaped_idx)
# or : target = torch.gather(source, 1, reshaped_idx)
我之前问过:
我现在有同样的问题,但需要使用二维索引张量。
我有一个大小为 [batch_size, k] 的张量列,其值介于 0 和 k-1 之间:
idx = tensor([[0,1,2,0],
[0,3,2,2],
...])
和以下矩阵:
x = tensor([[[0, 9],
[1, 8],
[2, 3],
[4, 9]],
[[0, 0],
[1, 2],
[3, 4],
[5, 6]]])
我想创建一个新的张量,其中包含索引中指定的行,按顺序排列。所以我想要:
tensor([[[0, 9],
[1, 8],
[2, 3],
[0, 9]],
[[0, 0],
[5, 6],
[3, 4],
[3, 4]]])
目前我是这样做的:
for i, batch in enumerate(t):
t[i] = batch[col[i]]
我怎样才能更有效地做到这一点?
您可以 gather
索引,经过轻微的操作使其成为兼容的形状:
>>> new_idx = idx.unsqueeze(-1).expand_as(x)
>>> x.gather(1, new_idx)
tensor([[[0, 9],
[1, 8],
[2, 3],
[0, 9]],
[[0, 0],
[5, 6],
[3, 4],
[3, 4]]])
您应该使用 torch gather 来实现此目的。它实际上也适用于您链接的其他问题,但这留作 reader :p
的练习让我们称 idx
你的第一个张量和 source
第二个张量。它们各自的维度是 (B,N)
和 (B, K, p)
(在您的示例中是 p=2
),并且 idx
的所有值都在 0
和 K-1
之间。
所以要使用torch gather,我们首先需要将你的操作表达为一个嵌套的for循环。在你的情况下,你真正想要实现的是:
for b in range(B):
for i in range(N):
for j in range(p):
# This kind of nested for loops is what torch.gether actually does
target[b,i,j] = source[b, idx[b,i,j], j]
但这不起作用,因为 idx
是 2D 张量,而不是 3D 张量。好吧,没什么大不了的,让我们把它变成一个 3D 张量。我们希望它的形状为 (B, N, p)
并且在最后一个维度上实际上是恒定的。然后我们可以用调用 gather
:
reshaped_idx = idx.unsqueeze(-1).repeat(1,1,2)
target = source.gather(1, reshaped_idx)
# or : target = torch.gather(source, 1, reshaped_idx)