如果目标是集合,如何定义损失函数或如何优化?

How to define the loss function or how to optimize if the target is a set?

我使用全连接网络从编码器的最后状态获取整个单词分布。

例如词汇表中有5个词。

P = [0.1, 0.1, 0.2, 0.2, 0,4]

ground truth 是这个训练数据的单词集。

我从 5 个词中抽取 3 个词,如果目标集包含 3 个词,那么我希望 P 中的 3 个词的概率增加,对于这种状态。

如果 3 个词中的一个不在目标集中,那么我希望 P 中的词的概率降低,对于这种状态。

所以我写了这些代码:

reward = [0,0,0]

假设前3个词是从P中采样出来的,而这3个词中只有前2个在目标集中。第三个词不在目标集中。那么

reward = [1,1,-1]

然后我计算 reward 的负和和点积,并采样 3 P2=[0.1, 0.1, 0.2] 作为损失

loss = -sum(reward * P2.log())

但是我没有得到结果:可以从每个州的词汇表中选择概率最高的词。

SQLNet的等式中找到答案: