Torch7:使用 ByteTensor 掩码对 Tensor 进行切片

Torch7: Slice Tensor using ByteTensor mask

我有两个张量:

  1. labels 是一维张量 (5000)
  2. 数据集是 4D Tensor (5000,1,32,32)

我想高效地对标签值 1 对应的标签和数据集进行切片。我成功地对标签进行了切片,但没有对数据集进行切片。

切片标签:

positive_mask = labels:eq(1)
sliced_labels = labels[positive_mask]

我尝试执行以下操作来对数据集进行切片但失败了:

sliced_dataset = dataset[positive_mask]
sliced_dataset = dataset[{positive_mask, {}, {}, {}}]
sliced_dataset = dataset:narrow(1,positive_mask)
sliced_dataset = dataset:select(1,positive_mask)

在 Torch7 中是否有优雅方法来执行此操作?

sliced_dataset = dataset:index(1, positive_mask:nonzero():squeeze())