如何在 PyTorch 中做掩码均值?

How to do a masked mean in PyTorch?

这是双向 rnn 的前向传递,我想在其中获取输出特征的平均池。如您所见,我试图从计算中排除带有 pad 标记的时间步长。

def forward(self, text):
    # text is shape (B, L)
    embed = self.embed(text)
    rnn_out, _ = self.rnn(embed)  # (B, L, 2*H)
    # Calculate average ignoring the pad token
    with torch.no_grad():
        rnn_out[text == self.pad_token] *= 0
        denom = torch.sum(text != self.pad_token, -1, keepdim=True)
    feat = torch.sum(rnn_out, dim=1) / denom
    feat =  self.dropout(feat)
    return feat

反向传播因第 rnn_out[text == self.pad_token] *= 0 行而引发异常。这是它的样子:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32, 21, 128]], which is output 0 of CudnnRnnBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

正确的做法是什么?

注意:我知道我可以通过执行以下 and/or 来做到这一点:

但我想知道是否有更简洁的方法不涉及这些。

您在禁用计算图构建的上下文中修改向量(并使用 *= 就地修改它),这将对梯度计算造成严重破坏。相反,我建议如下:

mask = text != self.pad_token
denom = torch.sum(mask, -1, keepdim=True)
feat = torch.sum(rnn_out * mask.unsqueeze(-1), dim=1) / denom

也许你需要稍微调整一下这个片段,我无法测试它,因为你没有提供完整的例子,但它希望展示你可以使用的技术。