PyTorch C++ API return 中的 randperm 不应该是默认类型为 int 的张量吗?
Shouldn't `randperm` in the PyTorch C++ API return a tensor with default type int?
当我尝试使用 C++ PyTorch API 生成具有 randperm
的置换整数索引列表时,生成的张量的元素类型为 CPUFloatType{10}
而不是整数类型:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
returns
9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
不能用于张量的索引,因为元素类型是浮点数而不是整数类型。当尝试使用 my_tensor.index(shuffled_indices)
我得到
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
环境:
- python-pytorch,Arch 上的版本 1.6.0-2 Linux
- g++ (海湾合作委员会) 10.1.0
为什么会这样?
这是因为您使用 torch 创建的任何张量的默认类型始终是 float
。如果你不想要,你必须用 TensorOptions
参数指定它 struct :
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long
当我尝试使用 C++ PyTorch API 生成具有 randperm
的置换整数索引列表时,生成的张量的元素类型为 CPUFloatType{10}
而不是整数类型:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
returns
9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
不能用于张量的索引,因为元素类型是浮点数而不是整数类型。当尝试使用 my_tensor.index(shuffled_indices)
我得到
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
环境:
- python-pytorch,Arch 上的版本 1.6.0-2 Linux
- g++ (海湾合作委员会) 10.1.0
为什么会这样?
这是因为您使用 torch 创建的任何张量的默认类型始终是 float
。如果你不想要,你必须用 TensorOptions
参数指定它 struct :
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long