从 TF model Zoo 导入模型并训练

importing model from TF model Zoo and training

参考本教程https://www.tensorflow.org/tutorials/images/transfer_learning,我创建并训练了一个 resnet 模型

preprocess_input = tf.keras.applications.resnet50.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.resnet50.ResNet50(
                    include_top=False, weights='imagenet',input_shape=IMG_SHAPE, classes=2)
prediction_layer = tf.keras.layers.Dense(1)
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

那么当我推断模型是否执行数据扩充时?我希望模型在训练期间而不是在推理期间进行数据扩充

当我将图像批量推断并一次推断一张图像时,我也会得到不同的结果。当我推断一批图像时,我总是得到准确度 1(这是一个过度拟合的模型),当我一张一张地推断图像时,我得到 2 - 4 个错误(这个数字不是恒定的,每次我得到不同的准确度)

这是我的推理代码

image_batch, label_batch = test_dataset.as_numpy_iterator().next()
class_list =['close','open']
model = tf.keras.models.load_model("shutter_model")
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

error = 0
for i in range(len(predictions)):
    if predictions[i]!= label_batch[i]:
        error+=1
print("number of errors when batch of images fed into the model: ",error)
        
print('=='*10)

error = 0
for i in range(len(image_batch)):
    img = tf.expand_dims(image_batch[i], axis=0)
    predictions = model(img)
    predictions = tf.nn.sigmoid(predictions)
    class_n = 1 if predictions[0][0] >0.5 else 0
    if label_batch[i]!= class_n:
        error+=1
print("number of errors when images fed into the model one by one: ",error)

输出

number of errors when batch of images fed into the model:  0
====================
number of errors when images fed into the model one by one:  3

我的目的是使用 Resnet5o 架构训练(从头开始或从预训练的权重)2 class 模型

在对一张图像进行推理时使用 model.predict(img) 而不是 model(img)