如果目标是集合,如何定义损失函数或如何优化?
How to define the loss function or how to optimize if the target is a set?
我使用全连接网络从编码器的最后状态获取整个单词分布。
例如词汇表中有5个词。
P = [0.1, 0.1, 0.2, 0.2, 0,4]
ground truth 是这个训练数据的单词集。
我从 5 个词中抽取 3 个词,如果目标集包含 3 个词,那么我希望 P
中的 3 个词的概率增加,对于这种状态。
如果 3 个词中的一个不在目标集中,那么我希望 P
中的词的概率降低,对于这种状态。
所以我写了这些代码:
reward = [0,0,0]
假设前3个词是从P
中采样出来的,而这3个词中只有前2个在目标集中。第三个词不在目标集中。那么
reward = [1,1,-1]
然后我计算 reward
的负和和点积,并采样 3 P2=[0.1, 0.1, 0.2]
作为损失
loss = -sum(reward * P2.log())
但是我没有得到结果:可以从每个州的词汇表中选择概率最高的词。
从SQLNet的等式中找到答案:
我使用全连接网络从编码器的最后状态获取整个单词分布。
例如词汇表中有5个词。
P = [0.1, 0.1, 0.2, 0.2, 0,4]
ground truth 是这个训练数据的单词集。
我从 5 个词中抽取 3 个词,如果目标集包含 3 个词,那么我希望 P
中的 3 个词的概率增加,对于这种状态。
如果 3 个词中的一个不在目标集中,那么我希望 P
中的词的概率降低,对于这种状态。
所以我写了这些代码:
reward = [0,0,0]
假设前3个词是从P
中采样出来的,而这3个词中只有前2个在目标集中。第三个词不在目标集中。那么
reward = [1,1,-1]
然后我计算 reward
的负和和点积,并采样 3 P2=[0.1, 0.1, 0.2]
作为损失
loss = -sum(reward * P2.log())
但是我没有得到结果:可以从每个州的词汇表中选择概率最高的词。
从SQLNet的等式中找到答案: