如何使用 AllenNLP 设置完全禁用 model/weight 序列化?

How to disable model/weight serialization fully with AllenNLP settings?

我希望通过使用 jsonnet 配置文件禁用标准 AllenNLP 模型训练中所有 model/state 权重的序列化。

原因是我运行使用Optuna进行自动超参数优化。 测试几十个模型很快就会填满一个驱动器。 我已经通过将 num_serialized_models_to_keep 设置为 0:

来禁用检查点
trainer +: {
    checkpointer +: {
        num_serialized_models_to_keep: 0,
    },

我不希望将 serialization_dir 设置为 None,因为我仍然想要有关中间指标等日志记录的默认行为。我只想 禁用默认模型state, training state, and best model weights writing.

除了我上面设置的选项之外,是否有任何默认的训练器或检查点选项来禁用模型权重的所有序列化?我检查了 API 文档和网页,但找不到任何内容。

如果我需要自己定义此类选项的功能,我应该在我的模型子class中覆盖 AllenNLP 中的哪些基本函数?

或者,它们是否有任何实用程序可用于在训练结束时清理中间模型状态?

编辑: 显示了自定义检查点的解决方案,但我不清楚如何让 allennlp train 为我找到此代码用例。

我希望 custom_checkpointer 可以从配置文件中调用,如下所示:

trainer +: {
    checkpointer +: {
        type: empty,
    },

调用 allennlp train --include-package <$my_package> 时加载检查点的最佳做法是什么?

我有 my_package 子目录中的子模块,例如 my_package/modelss 和 my_package/training。 我想将自定义检查点代码放在 my_package/training/custom_checkpointer.py 我的主要模型位于 my_package/models/main_model.py。 我是否必须在 main_model class 中编辑或导入任何 code/functions 才能使用自定义检查点?

您可以创建并注册一个基本上什么都不做的自定义 Checkpointer

@Checkpointer.register("empty")
class EmptyCheckpointer(Registrable):
    def maybe_save_checkpoint(
        self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int
    ) -> None:
        pass

    def save_checkpoint(
        self,
        epoch: Union[int, str],
        trainer: "allennlp.training.trainer.Trainer",
        is_best_so_far: bool = False,
        save_model_only=False,
    ) -> None:
        pass

    def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]:
        pass

    def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        return {}, {}

    def best_model_state(self) -> Dict[str, Any]:
        return {}