Tensorflow 2: No Longer Able to Track Attributes of a Subclassed Model After Loaded

Here is the subclassed model I implemented in TensorFlow 2.5:

from tensorflow.keras import Model, Input
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.random import uniform
from tensorflow.keras.models import load_model 

class Detector(Model):
    
    def __init__(self, num_classes=3, name="DenseNet201"):
        super(Detector, self).__init__(name=name)
        self.feature_extractor = DenseNet201(
            include_top=False,
            weights="imagenet",
        )
        self.feature_extractor.trainable = False
        self.flatten_layer = Flatten()
        self.prediction_layer = Dense(num_classes, activation=None)

    def call(self, inputs):
        x = preprocess_input(inputs)
        self.extracted_feature = self.feature_extractor(x, training=False)
        x = self.flatten_layer(self.extracted_feature)
        x = self.prediction_layer(x)
        return x

While testing my code, I noticed something that confused me.

detector = Detector()
print(detector.extracted_feature)

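This gives me an error, AttributeError: 'Detector' object has no attribute 'extracted_feature', which is understandable because I never called the model first. After the model has been called, the Detector object does have the attribute extracted_feature, so the following code runs without errors: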

image_tensor_1 = uniform(shape=(1, 600, 600, 3))
y_hat = detector(image_tensor_1)
print(detector.extracted_feature.shape)

However, after saving the model with detector.save("my_model") and loading it back into a new variable with new_detector = load_model("my_model"), I get an error when running the code below:

image_tensor_2 = uniform(shape=(1, 600, 600, 3))
y_hat = new_detector(image_tensor_2)
print(new_detector.extracted_feature.shape)

AttributeError: 'Detector' object has no attribute 'extracted_feature'

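I use self.extracted_feature to compute gradients, so I need to keep tracking it; otherwise the gradients are None. What should I do to access the attribute extracted_feature after loading?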

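You can return the extracted feature from call() as an extra output instead of storing it as an attribute: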

    def call(self, inputs):
        x = preprocess_input(inputs)
        extracted_feature = self.feature_extractor(x, training=False)
        x = self.flatten_layer(extracted_feature)
        x = self.prediction_layer(x)
        return extracted_feature, x
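A likely explanation for why the attribute disappears: when a Keras model is saved in the SavedModel format, call() is stored as traced TensorFlow graphs. A Python side effect such as self.extracted_feature = ... executes only while Python code is being traced, not when the saved graph runs, whereas returned tensors are real outputs of the graph. The sketch below shows the same effect with a plain tf.function; trace_log and double are hypothetical names used only for illustration.

```python
import tensorflow as tf

trace_log = []  # hypothetical list used to observe Python side effects


@tf.function
def double(x):
    # This Python statement runs only while TensorFlow traces the function,
    # not on every call of the compiled graph.
    trace_log.append("python body ran")
    return x * 2


a = double(tf.constant(1.0))
b = double(tf.constant(2.0))  # same dtype and shape: reuses the traced graph
print(float(a), float(b), len(trace_log))  # 2.0 4.0 1
```

The tensor results are correct on both calls, but the Python side effect happened only once. A loaded SavedModel runs only the graph, so an attribute assignment inside call() is never replayed.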

Checking it:

image_tensor_1 = uniform(shape=(1, 32, 32, 3))
detector = Detector()
ex_feat, y_hat = detector(image_tensor_1)
print(ex_feat.shape)
(1, 1, 1, 1920)

Save and reload:

detector.save("my_model")
new_detector = load_model("my_model")

image_tensor_2 = uniform(shape=(1, 32, 32, 3))
ex_feat, y_hat = new_detector(image_tensor_2)
print(ex_feat.shape)
(1, 1, 1, 1920)
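Since the goal is to compute gradients with respect to the extracted feature, here is a minimal sketch of doing that with tf.GradientTape once the feature is returned from call(). Assumptions: TinyDetector is a hypothetical stand-in that uses a small frozen Conv2D instead of DenseNet201 so it runs without downloading weights, and the gradient target is the top predicted class's logit, which is just one possible choice. Watching the input makes the tape record the whole forward pass, so the gradient with respect to the intermediate feature is not None even when the backbone is frozen.

```python
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten


class TinyDetector(Model):
    """Hypothetical stand-in for Detector with a small frozen conv backbone."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.feature_extractor = Conv2D(8, 3, strides=2, activation="relu")
        self.feature_extractor.trainable = False  # frozen, like the DenseNet backbone
        self.flatten_layer = Flatten()
        self.prediction_layer = Dense(num_classes)

    def call(self, inputs):
        extracted_feature = self.feature_extractor(inputs)
        x = self.flatten_layer(extracted_feature)
        # Return the feature as an output instead of storing it on self.
        return extracted_feature, self.prediction_layer(x)


model = TinyDetector()
images = tf.random.uniform((1, 32, 32, 3))

with tf.GradientTape() as tape:
    # Watch the input so every op downstream of it, including the
    # intermediate feature, is recorded on the tape.
    tape.watch(images)
    feature, logits = model(images)
    top_class = int(tf.argmax(logits[0]))
    score = logits[:, top_class]

# Gradient of the chosen class score with respect to the intermediate feature.
grads = tape.gradient(score, feature)
print(feature.shape, grads.shape)  # (1, 15, 15, 8) (1, 15, 15, 8)
```

The same pattern should apply to the answer's Detector (and to a reloaded new_detector), since the feature is now part of the model's outputs rather than a Python attribute.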

FYI, if you want to get intermediate-layer outputs from the base model, you may need to initialize your base model in the __init__ method like this: