打印 PyTorch 张量的精确值(浮点精度)

Print exact value of PyTorch tensor (floating point precision)

我正在尝试打印 torch.FloatTensor 如:

a = torch.FloatTensor(3,3)
print(a)

这样我可以获得如下值:

0.0000e+00  0.0000e+00  3.2286e-41
1.2412e-40  1.2313e+00  1.6751e-37
2.6801e-36  3.5873e-41  9.4463e+21

但我想得到更准确的值,比如小数点后10位:

0.1234567891+01

对于其他 python 数字对象,我可以通过以下方式获得它:

print('{:.10f}'.format(a))

但是在张量的情况下,我得到这个错误:

TypeError: unsupported format string passed to torch.FloatTensor.__format__

如何打印更精确的张量值?

您可以设置精度选项:

torch.set_printoptions(precision=10)

documentation page上有更多的格式化选项,和numpy的非常相似

作为旁注,此功能取自 numpy。 PyTorch 之所以聪明的原因之一是因为他们从 numpy 中吸取了很多好的想法。

但是,在 numpy 中,默认精度为 8,在 PyTorch 中,默认精度为 4。