pytorch 中批量规范挂钩的未知行为

Unknown behavior of hooks on batch norm in pytorch

我尝试冻结 batch_norm 层并使用正向钩子

分析它们的 inputs/outputs

对于固定的 BN 层,我无法理解为什么 hook 输出与 hook 输入再现的输出不同。

非常感谢,如果有人能帮助我

代码如下:

import torch
import torchvision
import numpy


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

image = torch.randn((1, 3, 224, 224))
res = torchvision.models.resnet50(pretrained=True)
res.apply(set_bn_eval)
b = res(image)
layer_out = []
layer_in = []

def layer_hook(mod, inp, out):
    layer_out.append(out)
    layer_in.append(inp[0])

for name, key in res.named_modules():
    hook = key.register_forward_hook(layer_hook)
    res(image)
    hook.remove()
    out = layer_out.pop()
    inp = layer_in.pop()
    try:
        assert (out.equal(key(inp)))
    except AssertionError:
        print(name)
        break

TLDR;有些运算符只会出现在模块的forward中:比如非参数化层。

有些组件没有在子模块列表中注册。这通常是激活函数的情况,但最终取决于模块的实现。在您的例子中,ResNet's Bottleneck section as its ReLUs applied in the forward definition,就在批归一化层被调用之后。

这意味着您将使用层钩子捕获的输出将不同于您仅从模块及其输入计算的张量。

for name, module in res.named_modules():
   if name != 'bn1':
     hook = module.register_forward_hook(layer_hook)
     res(image)
     hook.remove()
     inp = layer_in.pop()
     out = layer_out.pop()
     assert out.equal(F.relu(module(inp)))

因此,实际实现起来有点棘手,因为你不能完全依赖 res.named_modules() 的内容。