如何为 PyTorch matmul 注册前向钩子?

How to register a forward hook for PyTorch matmul?

torch.matmul 似乎没有 nn.Module 包装器来允许按名称进行标准的前向挂钩注册。在这种情况下,矩阵乘法发生在 forward() 函数的中间。我想中间结果除了最终结果还可以通过forward()返回,比如return x, mm_res。但是收集这些额外输出的好方法是什么?

卸载 torch.matmul 输出的选项有哪些? TIA.

如果您的主要抱怨是 torch.matmul 没有模块包装器,那么制作一个如何?

class Matmul(nn.Module):
    def forward(self, *args):
        return torch.matmul(*args)

现在您可以在 Matmul 实例上注册前向挂钩

class Network(nn.Module):
    def __init__(self, ...):
        self.matmul = Matmul()
        self.matmul.register_module_forward_hook(...)
    def forward(self, x):
        y = ...
        z = self.matmul(x, y)
        ...

话虽如此,您一定不能忽视警告(红色)in the doc它应该只用于调试目的。