我如何检测是否在pytorch中触发了回调?
How can i detect if a callback is triggered in pytorch?
我正在微调 BERT 模型。首先,我想冻结图层并训练一下。当触发某个回调时(假设 ReduceLROnPlateau
),我想解冻图层。我该怎么做?
恐怕 PyTorch 中的学习率调度程序不提供挂钩。查看 ReduceLROnPlateau
here 的实现,调度程序被触发时会重置两个属性(i.e.
当它识别出高原并降低学习率时):
if self.num_bad_epochs > self.patience:
self._reduce_lr(epoch)
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
基于此,您可以包装您的调度程序步骤调用,并通过检查 scheduler.cooldown_counter == scheduler.cooldown
和 scheduler.num_bad_epochs == 0
是否为真来确定 _reduce_lr
是否被触发。
我正在微调 BERT 模型。首先,我想冻结图层并训练一下。当触发某个回调时(假设 ReduceLROnPlateau
),我想解冻图层。我该怎么做?
恐怕 PyTorch 中的学习率调度程序不提供挂钩。查看 ReduceLROnPlateau
here 的实现,调度程序被触发时会重置两个属性(i.e.
当它识别出高原并降低学习率时):
if self.num_bad_epochs > self.patience:
self._reduce_lr(epoch)
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
基于此,您可以包装您的调度程序步骤调用,并通过检查 scheduler.cooldown_counter == scheduler.cooldown
和 scheduler.num_bad_epochs == 0
是否为真来确定 _reduce_lr
是否被触发。