为什么我应该调用 BERT 模块实例而不是 forward 方法?
Why should I call a BERT module instance rather than the forward method?
我正在尝试使用变形金刚库中的 BERT 提取文本的矢量表示,并且偶然发现了 "BERTModel" class 的 documentation 的以下部分:
有人可以更详细地解释一下吗?前向传递对我来说很直观(毕竟我试图获得最终的隐藏状态),而且我找不到任何关于 "pre and post processing" 在这种情况下意味着什么的额外信息。
先谢谢了!
我认为这只是关于使用 PyTorch Module
的一般建议。 transformers
模块是 nn.Module
,它们需要 forward
方法。但是,不应手动调用 model.forward()
,而应调用 model()
。原因是 PyTorch 在调用模块时会在后台做一些事情。您可以在 the source code.
中找到它
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
您会看到 forward
在必要时被调用。
我正在尝试使用变形金刚库中的 BERT 提取文本的矢量表示,并且偶然发现了 "BERTModel" class 的 documentation 的以下部分:
有人可以更详细地解释一下吗?前向传递对我来说很直观(毕竟我试图获得最终的隐藏状态),而且我找不到任何关于 "pre and post processing" 在这种情况下意味着什么的额外信息。
先谢谢了!
我认为这只是关于使用 PyTorch Module
的一般建议。 transformers
模块是 nn.Module
,它们需要 forward
方法。但是,不应手动调用 model.forward()
,而应调用 model()
。原因是 PyTorch 在调用模块时会在后台做一些事情。您可以在 the source code.
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
您会看到 forward
在必要时被调用。