PyTorch C++ API return 中的 randperm 不应该是默认类型为 int 的张量吗?

Shouldn't `randperm` in the PyTorch C++ API return a tensor with default type int?

当我尝试使用 C++ PyTorch API 生成具有 randperm 的置换整数索引列表时,生成的张量的元素类型为 CPUFloatType{10} 而不是整数类型:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES); 
cout << shuffled_indices << endl;

returns

 9
 3
 8
 6
 2
 5
 4
 7
 1
 0
[ CPUFloatType{10} ]

不能用于张量的索引,因为元素类型是浮点数而不是整数类型。当尝试使用 my_tensor.index(shuffled_indices) 我得到

terminate called after throwing an instance of 'c10::IndexError'
  what():  tensors used as indices must be long, byte or bool tensors

环境:

为什么会这样?

这是因为您使用 torch 创建的任何张量的默认类型始终是 float。如果你不想要,你必须用 TensorOptions 参数指定它 struct :

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long