tensorflow2.3加载模型失败有什么解决办法吗?
Is there any solution for failing to load model in tensorflow2.3?
我尝试使用 tf.keras.models.load_model 在 tensorflow 2.3 中加载保存的模型。
但是,我得到了同样的错误
https://github.com/tensorflow/tensorflow/issues/41535
看来是个重要的功能。但是这个问题仍然没有解决。有谁知道是否有任何替代方法可以实现相同的结果?
我找到了一种在 tensorflow 2.3 中加载自定义模型的替代方法。您需要进行以下更改。我将通过一些代码快照来解释
for __init__()
自定义模型。之前,
def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
layers = []
layer_configs = {}
if 'layers' in kwargs.keys():
layer_configs = kwargs['layers']
for config in layer_configs:
layer = tf.keras.layers.deserialize(config)
layers.append(layer)
super(custom_model, self).__init__(layers) # custom_model is your custom model class
self.mask_ratio = mask_ratio
self.hyperparam = hyperparam
...
之后,
def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
super(custom_model, self).__init__() # custom_model is your custom model class
self.mask_ratio = mask_ratio
self.hyperparam = hyperparam
...
在自定义模型中定义两个函数class
def get_config(self):
config = {
'mask_ratio': self.mask_ratio,
'hyperparam': self.hyperparam
}
base_config = super(custom_model, self).get_config()
return dict(list(config.items()) + list(base_config.items()))
@classmethod
def from_config(cls, config):
#config = cls().get_config()
return cls(**config)
完成训练后,使用'h5'格式保存模型
model.save(file_path, save_format='h5')
最后加载模型如下代码,
model = tf.keras.models.load_model(model_path, compile=False, custom_objects={'custom_model': custom_model})
我尝试使用 tf.keras.models.load_model 在 tensorflow 2.3 中加载保存的模型。 但是,我得到了同样的错误 https://github.com/tensorflow/tensorflow/issues/41535
看来是个重要的功能。但是这个问题仍然没有解决。有谁知道是否有任何替代方法可以实现相同的结果?
我找到了一种在 tensorflow 2.3 中加载自定义模型的替代方法。您需要进行以下更改。我将通过一些代码快照来解释
for
__init__()
自定义模型。之前,def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs): layers = [] layer_configs = {} if 'layers' in kwargs.keys(): layer_configs = kwargs['layers'] for config in layer_configs: layer = tf.keras.layers.deserialize(config) layers.append(layer) super(custom_model, self).__init__(layers) # custom_model is your custom model class self.mask_ratio = mask_ratio self.hyperparam = hyperparam ...
之后,
def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs): super(custom_model, self).__init__() # custom_model is your custom model class self.mask_ratio = mask_ratio self.hyperparam = hyperparam ...
在自定义模型中定义两个函数class
def get_config(self): config = { 'mask_ratio': self.mask_ratio, 'hyperparam': self.hyperparam } base_config = super(custom_model, self).get_config() return dict(list(config.items()) + list(base_config.items())) @classmethod def from_config(cls, config): #config = cls().get_config() return cls(**config)
完成训练后,使用'h5'格式保存模型
model.save(file_path, save_format='h5')
最后加载模型如下代码,
model = tf.keras.models.load_model(model_path, compile=False, custom_objects={'custom_model': custom_model})