PyTorch:torch.arange 中的不一致行为

PyTorch: Inconsistent behavior in torch.arange

我正在使用 Google Colab,当我 运行 以下代码时:

from torch import tensor, arange

print( arange(0.0, 1.2, 0.2) )
print( arange(tensor(0.0), tensor(1.2), tensor(0.2)) )

我得到输出:

tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000])
tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 1.2000])

arange(0.0,1.1,0.1) 也存在这种差异,但 arange(0.0,1.5,0.5) 没有差异。

为什么看似相似的代码会产生不同的结果,我如何预测这种情况何时会发生?

这是一个数值精度问题:(

默认情况下,Python 以双精度存储浮点数(又名 float64),而 PyTorch 默认使用 float32

如果你尝试:

tensor(1.2).item()
tensor(1.2).dtype  # torch.float32

你会得到 1.2000000476837158,并且 arange1.2 不同。在这种特定情况下,如果您尝试:

import torch
from torch import tensor, arange

print(arange(tensor(0.0), tensor(1.2, dtype=torch.float64), tensor(0.2)))

你会得到你所期望的,但是即使 float64 最终也会有一些精度问题。无论如何,由于 Python 也使用双精度,在这种情况下来回转换标量不会有这个问题。

不知道你能不能预料到。