如何挑选使用 lambda 函数的任意 pytorch 模型?

How does one pickle arbitrary pytorch models that use lambda functions?

我目前有一个神经网络模块:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self,args,lambda_f,nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Params stuff etc...

    def forward(self, x):
        #some code using fields
        return out

我正在尝试对其进行检查点,但是因为 pytorch 使用 state_dicts 进行保存,这意味着如果我使用 pytorch torch.save 等进行检查点,我将无法保存我实际使用的 lambda 函数。我真的想毫无问题地保存所有内容,然后重新加载以便稍后在 GPU 上进行训练。我目前正在使用这个:

def save_ckpt(path_to_ckpt):
    from pathlib import Path
    import dill as pickle
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / Path('db')

    ## Pickle args
    db['crazy_mdl'] = crazy_mdl
    with open(ckpt_path_plus_path , 'ab') as db_file:
        pickle.dump(db, db_file)

目前,当我检查它并保存它时,它不会抛出任何错误。

我担心当我训练它时可能会有一个微妙的错误,即使没有 exceptions/errors 被训练或可能发生意外的事情(例如,在集群中的磁盘上奇怪的保存等谁知道)。

使用 pytorch classes/nn 模型安全吗?特别是如果我们想恢复使用 GPU 进行训练?

交叉发布:

我是 dill 的作者。我使用 dill(和 klepto)来保存 classes,其中包含 lambda 函数内部经过训练的 ANN。我倾向于使用 mysticsklearn 的组合,所以我不能直接与 pytorch 对话,但我可以假设它的工作原理相同。你必须小心的地方是,如果你有一个 lambda,它包含一个指向 lambda 外部对象的指针……所以例如 y = 4; f = lambda x: x+y。这看起来很明显,但是 dill 会 pickle lambda,并且根据代码的其余部分和序列化变体,可能不会序列化 y 的值。因此,我见过很多情况,人们在某些函数(或 lambda,或 class)中序列化经过训练的估算器,然后当他们从序列化中恢​​复函数时,结果不是 "correct"。首要原因是因为函数没有被封装,所以函数产生正确结果所需的所有对象都存储在 pickle 中。但是,即使在那种情况下,您也可以获得 "correct" 结果,但是您只需要创建与腌制估算器时相同的环境(即它在周围命名空间中所依赖的所有相同值) .要点应该是,尽量确保函数中使用的所有变量都在函数中定义。这是我最近开始使用的 class 的一部分(应该在 mystic 的下一个版本中):

class Estimator(object):
    "a container for a trained estimator and transform (not a pipeline)"
    def __init__(self, estimator, transform):
        """a container for a trained estimator and transform

    Input:
        estimator: a fitted sklearn estimator
        transform: a fitted sklearn transform
        """
        self.estimator = estimator
        self.transform = transform
        self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
    def __call__(self, *x):
        "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
        import numpy as np
        return self.function(*x)

请注意,调用该函数时,它使用的所有内容(包括 np)都在周围的命名空间中定义。只要 pytorch 估计器按预期序列化(没有外部引用),那么如果您遵循上述准则,应该没问题。

是的,我认为使用 dill 来 pickle lambda 函数等是安全的。我一直在使用 torch.save 和 dill 来保存状态字典,并且在 GPU 上恢复训练没有问题,因为以及 CPU 除非模型 class 已更改。即使模型 class 发生了变化(adding/deleting 一些参数),我也可以加载状态字典,修改它,然后加载到模型中。

此外,通常情况下,人们不会保存模型对象,而只会声明指令,即恢复训练的参数值以及 hyperparameters/model 稍后获得相同模型对象的参数。

保存模型对象有时会出现问题,因为更改模型 class(代码)会使保存的对象无用。如果您根本不打算更改模型 class/code,因此模型对象不会更改,那么保存对象可能会很好,但通常不建议 pickle 模块对象。

这不是个好主意。如果您这样做,那么如果您的代码更改为不同的 github 存储库,那么将很难恢复您花费大量时间训练的模型。恢复那些或再培训所花费的周期是不值得的。我建议改用 pytorch 的方式,只保存他们在 pytorch 中推荐的权重。