使用 torch.max() 时,为批次中的每个条目屏蔽某些索引
Mask certain indices for every entry in a batch, when using torch.max()
我正在对大小为 torch.Size([n, 8])
的 batch
进行增量采样。
我还有一个长度为 n 的列表 valid_indices
,其中包含对批次中的每个条目都有效的索引元组。
例如 valid_indices[0]
可能看起来像这样: (0,1,3,4,5,7)
,这表明索引 2 和 6 应该从 batch
的第一个条目中排除在 dim 1.
特别是当我使用 torch.max(batch, dim=1, keepdim=True)
.
时,我需要排除这些值
要排除的索引(如果有)可能因批次中的条目而异。
有什么想法吗?提前致谢。
我假设你正在变老
IndexError: too many indices for tensor of dimension 1
直接在张量上使用元组索引时出错。
至少这是我在执行以下行时能够重现的错误
t[0][valid_idx0]
其中 t 是大小为 (10,8) 的随机张量,valid_idx0 是具有 4 个元素的元组。
但是,当您将元组转换为如下列表时,同一行工作得很好
t[0][list(valid_idx0)]
>>> tensor([0.1847, 0.1028, 0.7130, 0.5093])
但是当涉及到将这些索引应用于二维张量时,情况会有所不同,因为我们需要保留张量的结构以进行批处理。
因此,将我们的索引转换为掩码数组是合理的。
假设我们手头有一个元组列表 valid_indices
。首先是将其转换为列表列表。
valid_idx_list = [list(tup) for tup in valid_indices]
第二件事是将它们转换为掩码数组。
masks = np.zeros((t.size()))
for i, indices in enumerate(valid_idx_list):
masks[i][indices] = 1
完成。现在我们可以应用掩码并在掩码张量上使用 torch.max。
torch.max(t*masks)
请查看我用来重现该问题的 colab notebook。
https://colab.research.google.com/drive/1BhKKgxk3gRwUjM8ilmiqgFvo0sfXMGiK?usp=sharing
我正在对大小为 torch.Size([n, 8])
的 batch
进行增量采样。
我还有一个长度为 n 的列表 valid_indices
,其中包含对批次中的每个条目都有效的索引元组。
例如 valid_indices[0]
可能看起来像这样: (0,1,3,4,5,7)
,这表明索引 2 和 6 应该从 batch
的第一个条目中排除在 dim 1.
特别是当我使用 torch.max(batch, dim=1, keepdim=True)
.
要排除的索引(如果有)可能因批次中的条目而异。
有什么想法吗?提前致谢。
我假设你正在变老
IndexError: too many indices for tensor of dimension 1
直接在张量上使用元组索引时出错。 至少这是我在执行以下行时能够重现的错误
t[0][valid_idx0]
其中 t 是大小为 (10,8) 的随机张量,valid_idx0 是具有 4 个元素的元组。
但是,当您将元组转换为如下列表时,同一行工作得很好
t[0][list(valid_idx0)]
>>> tensor([0.1847, 0.1028, 0.7130, 0.5093])
但是当涉及到将这些索引应用于二维张量时,情况会有所不同,因为我们需要保留张量的结构以进行批处理。
因此,将我们的索引转换为掩码数组是合理的。
假设我们手头有一个元组列表 valid_indices
。首先是将其转换为列表列表。
valid_idx_list = [list(tup) for tup in valid_indices]
第二件事是将它们转换为掩码数组。
masks = np.zeros((t.size()))
for i, indices in enumerate(valid_idx_list):
masks[i][indices] = 1
完成。现在我们可以应用掩码并在掩码张量上使用 torch.max。
torch.max(t*masks)
请查看我用来重现该问题的 colab notebook。
https://colab.research.google.com/drive/1BhKKgxk3gRwUjM8ilmiqgFvo0sfXMGiK?usp=sharing