TensorFlow 二值图像分类:预测数据集中每个图像的每个 class 的概率

TensorFlow Binary Image Classification: Predict Probability of each class for each image in data set

我正在为二值图像分类构建一个 TensorFlow 模型。我有两个标签“好”和“坏” 我希望模型应该为数据集中的每个图像输出,无论该图像是好是坏以及概率是多少

例如,如果我提交 1.jpg 并且假设它是“好”图像。那么模型应该预测 1.jpg 好,概率为 100%,坏,概率为 0%。

到目前为止,我已经能够想出以下内容

model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(input_shape, input_shape, 3)),
  tf.keras.layers.MaxPool2D(2,2),
  #
  tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
  tf.keras.layers.MaxPool2D(2,2),
  #
  tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
  tf.keras.layers.MaxPool2D(2,2),
  ##
  tf.keras.layers.Flatten(),
  ##
  tf.keras.layers.Dense(512, activation='relu'),
  ##
  tf.keras.layers.Dense(1, activation='sigmoid')
])

上述模型输出的形状是 1 x 1。但我认为这不符合我的目的。

我就是这样编译模型的

 model.compile(loss='binary_crossentropy',
          optimizer=RMSprop(lr=0.001),
          metrics=['accuracy'])
 model_fit = model.fit(train_dataset,
                  steps_per_epoch=3,
                  epochs=30,
                  validation_data=validation_dataset)

非常感谢任何帮助。

你不必让你的模型输出“好”和“坏”作为标签,相反,你可以独立地输出每一个的概率,换句话说,图像好的概率和概率形象不好。将最后一层的输出大小设置为 2。因此,您的模型现在将输出一个二维向量,其中 [1.0, 0.0] 表示 100% 好,0% 坏,[0.0, 1.0] 表示 0% 好和 100% 坏。使用二元交叉熵作为训练的损失函数。当然,你必须类似地标记你的训练数据,所以如果你有一个好的训练示例,将其标记为 [1.0, 0.0] 因为你 100% 确定它是好的,如果你有一个不好的训练示例,则将其标记为[0.0, 1.0] 因为你也 100% 确定这是一个不好的例子。

我告诉你使用二元交叉熵作为损失函数的原因是模型将学习为二维向量输出的分量输出相反的概率。所以如果它是一个好的图像,第一个分量会高,第二个分量会低,反之亦然,如果它是一个坏图像。另外,在训练之后,在进行预测时,你只取两者中概率最高的,如果概率较高的是第一个,那么它是一个“好”图像,你只使用那个概率。

如果有人正在寻找答案,下面是 python 模型生成代码

这里需要注意的一些要点是

  1. 输入图像形状为 360x360x3
  2. 最后一层的激活函数是“softmax”而不是“sigmoid
  3. 损失函数是“sparse_categorical_crossentropy”而不是“binary_crossentropy
  4. 输出的形状是 2 而不是 1

请注意#2、#3 和#4,尽管我正试图提出一个二值图像分类模型。我的最终目标是将此模型转换为 TensorFlow Lite 版本并在 Android 应用程序中使用 TensorFlow Lite 模型。

之前,当我在最后一层使用“sigmoid”和“binary_crossentropy”作为损失函数时,最后一层的输出形状不能大于1。

因此,当我在 Android 应用程序中使用从该 TensorFlow 模型生成的 Lite 模型时,我遇到了如下所述的错误

"Cannot find an axis to label. A valid axis to label should have size larger than 1"

通过#2、#3 和#4 中提到的更改,生成的 Lite 模型在 Android 中运行良好。

import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.optimizers import RMSprop


print("version")
print(tf.__version__)

train = ImageDataGenerator(rescale=1/255)
validation = ImageDataGenerator(rescale=1/255)

input_shape = 360
train_dataset = train.flow_from_directory('container_images/train/',
                                          target_size=(input_shape,input_shape),
                                          batch_size=3,
                                          classes=['good', 'bad'],
                                          class_mode='binary')

validation_dataset = train.flow_from_directory('container_images/validation/',
                                          target_size=(input_shape,input_shape),
                                          batch_size=3,
                                          classes=['good', 'bad'],
                                          class_mode='binary')

print(train_dataset.class_indices)
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(input_shape, input_shape, 3)),
    tf.keras.layers.MaxPool2D(2,2),
    #
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPool2D(2,2),
    #
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPool2D(2,2),
    ##
    tf.keras.layers.Flatten(),
    ##
    tf.keras.layers.Dense(512, activation='relu'),
    ##
    tf.keras.layers.Dense(2, activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=RMSprop(lr=0.001),
              metrics=['accuracy'])
model_fit = model.fit(train_dataset,
                      steps_per_epoch=3,
                      epochs=30,
                      validation_data=validation_dataset)