'Reduction' tf.keras.losses 中的参数

'Reduction' parameter in tf.keras.losses

根据 docsReduction 参数有 3 个值 - SUM_OVER_BATCH_SIZESUMNONE

y_true = [[0., 2.], [0., 0.]]
y_pred = [[3., 1.], [2., 5.]]

mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5

mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 2.75

经过各种尝试,我可以推断出的计算结果是:-

因此,SUM_OVER_BATCH_SIZE 只是 SUM/batch_size。那么,为什么叫SUM_OVER_BATCH_SIZESUM实际上是把整个batch的损失加起来,而SUM_OVER_BATCH_SIZE计算的是batch的平均损失。

我关于 SUM_OVER_BATCH_SIZESUM 的假设完全正确吗?

据我了解,你的假设是正确的。

如果你检查 github [keras/losses_utils.py][1] 行 260-269 您会看到它确实按预期执行。 SUM 将在批量维度中总结损失,SUM_OVER_BATCH_SIZESUM 除以总损失数(批量大小)。

def reduce_weighted_loss(weighted_losses,
                     reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  if reduction == ReductionV2.NONE:
     loss = weighted_losses
  else:
     loss = tf.reduce_sum(weighted_losses)
     if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
        loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss

您可以通过添加一对损失为 0 的输出来轻松检查您之前的示例。

y_true = [[0., 2.], [0., 0.],[1.,1.]]
y_pred = [[3., 1.], [2., 5.],[1.,1.]]

mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5

mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 1.8333

所以,你的假设是正确的。 [1]: https://github.com/keras-team/keras/blob/v2.7.0/keras/utils/losses_utils.py#L25-L84