tf.reduce_sum() uint8 的意外结果

tf.reduce_sum() unexpected result with uint8

为什么 tf.reduce_sum() 不适用于 uint8

考虑这个例子:

>>> tf.reduce_sum(tf.ones((4, 10, 10), dtype=tf.uint8))
<tf.Tensor: shape=(), dtype=uint8, numpy=144>

>>> tf.reduce_sum(tf.ones((4, 10, 10), dtype=tf.uint16))
<tf.Tensor: shape=(), dtype=uint16, numpy=400>

有人知道这是为什么吗?

docs 没有提到与 uint8 的任何不兼容。

uint8代表无符号整数,为了保持值,它得到8位。

8位,只能保存[0, 255]范围内的正数(无符号)(如果是int8可以保存[-127,+127]范围内的有符号数)。

如果要保留大于 255 的值,它只保留该数字的前 8 位,例如256在二进制中是0000 0000 1,前8位是0000 0000。因此,对于 256,您将得到 0 作为结果:

>>> tf.reduce_sum(tf.ones((1, 255), dtype=tf.uint8))
    <tf.Tensor: shape=(), dtype=uint8, numpy=255>

>>> tf.reduce_sum(tf.ones((1, 256), dtype=tf.uint8))
    <tf.Tensor: shape=(), dtype=uint8, numpy=0>

在您的情况下,预期结果为 400,但由于 uint8 无法保持高于 255 的值,因此当总和达到 256 时它将从 0 开始。因此,您看到的结果为 144,即实际上是 400-256=144.

所以,它不在 tf.reduce_sum(),它在 uint8,注意使用任何数据类型。