如何使用 tf.data 在 CNN 中包含 rgb 和灰度图像?

How to include rgb and grayscale images in CNN using tf.data?

我正在尝试基于 tf.data,使用 rgb 图像作为输入、灰度图像作为标签图像。如何修改以下代码,使标签图像被定义为只包含一个通道?

# step 1: wrap the two Python lists in constant tensors
# (presumably lists of image file paths -- confirm against where
# `input_list` / `label_list` are built)
filenames, labels = tf.constant(input_list), tf.constant(label_list)

# step 2: build a dataset that yields one (filename, label) pair per element
dataset = tf.data.Dataset.from_tensor_slices(
    (filenames, labels)
)

# step 3: parse every image in the dataset using `map`
def _parse_function(filename, label):
    """Load the JPEG at `filename` as a 3-channel float32 tensor.

    The file is read, decoded with ``channels=3`` and cast to
    ``tf.float32`` (a plain cast, no rescaling). ``label`` is passed
    through untouched -- it is NOT read or decoded here.

    Returns:
        A ``(image, label)`` tuple.
    """
    raw_bytes = tf.io.read_file(filename)
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)
    return tf.cast(rgb, tf.float32), label

# decode every element with `_parse_function`, then group into batches of 2
dataset = dataset.map(_parse_function).batch(2)

# step 4: create a one-shot iterator (TF1 compat API) and pull the
# tensors for the next batch from it
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
images, labels = iterator.get_next()

使用此函数加载具有不同通道数的图像:

def _parse_function(filename, channels):
    """Load the JPEG at `filename` with the requested number of channels.

    Args:
        filename: Path of the JPEG file to read.
        channels: Channel count forwarded to ``tf.image.decode_jpeg``
            (the caller uses 3 for the RGB input and 1 for the
            grayscale label).

    Returns:
        The decoded image converted to ``tf.float32`` via
        ``tf.image.convert_image_dtype``.
    """
    encoded = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(encoded, channels=channels)
    return tf.image.convert_image_dtype(decoded, tf.float32)

然后:

# Decode both members of each pair: the first as a 3-channel image,
# the second as a single-channel image.
dataset = dataset.map(
    lambda image_path, label_path: (
        _parse_function(image_path, channels=3),
        _parse_function(label_path, channels=1),
    )
)