序列的损失函数(在 Tensorflow 2.0 中)

Loss function for sequences (in Tensorflow 2.0)

我正在研究将句子从英语翻译成德语的问题。 所以最终输出是一个德语序列,我需要检查我的预测有多好。

我在tensorflow教程中找到了以下损失函数:

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

但是我不知道这个函数是做什么的。我知道(也许我错了)我们不能以直接的方式对序列使用 SparseCategoricalCrossentropy,我们必须进行某种操作。 但是例如在上面的代码中,我看到 SparseCategoricalCrossentropy 以直接的方式用于序列输出。为什么?

mask 变量有什么作用? 你能解释一下代码吗?

编辑: 教程- https://www.tensorflow.org/tutorials/text/nmt_with_attention

mask = tf.math.logical_not(tf.math.equal(real, 0)) 中的

mask 正在处理 PADDING

所以,在你的批次中你会有不同长度的句子,你会做 0 填充以使所有的句子长度相等(想想 I have an apple v/s It's a good day to play football in the sun)

但是,在损失计算中包含 0 填充部分没有意义 - 因此,它首先查看具有 0 的索引,然后再使用乘法使他们的损失贡献为 0.