tensorflow2.3加载模型失败有什么解决办法吗?

Is there any solution for failing to load model in tensorflow2.3?

我尝试使用 tf.keras.models.load_model 在 tensorflow 2.3 中加载保存的模型。 但是,我得到了同样的错误 https://github.com/tensorflow/tensorflow/issues/41535

看来是个重要的功能。但是这个问题仍然没有解决。有谁知道是否有任何替代方法可以实现相同的结果?

我找到了一种在 tensorflow 2.3 中加载自定义模型的替代方法。您需要进行以下更改。我将通过一些代码快照来解释

  • for __init__() 自定义模型。之前,

    def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
        layers = []
        layer_configs = {}
        if 'layers' in kwargs.keys():
            layer_configs = kwargs['layers']
        for config in layer_configs:
            layer = tf.keras.layers.deserialize(config)
            layers.append(layer)
        super(custom_model, self).__init__(layers)  # custom_model is your custom model class
        self.mask_ratio = mask_ratio
        self.hyperparam = hyperparam
        ...
    

    之后,

    def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
        super(custom_model, self).__init__()  # custom_model is your custom model class
        self.mask_ratio = mask_ratio
        self.hyperparam = hyperparam
        ...
    
  • 在自定义模型中定义两个函数class

    def get_config(self):
        config = {
            'mask_ratio': self.mask_ratio,
            'hyperparam': self.hyperparam
        }
        base_config = super(custom_model, self).get_config()
        return dict(list(config.items()) + list(base_config.items()))
    @classmethod
    def from_config(cls, config):
        #config = cls().get_config()
        return cls(**config)
    
  • 完成训练后,使用'h5'格式保存模型

    model.save(file_path, save_format='h5')
    
  • 最后加载模型如下代码,

    model = tf.keras.models.load_model(model_path, compile=False, custom_objects={'custom_model': custom_model})