Tensorflow:如果灰度,则将图像转换为 rgb

Tensorflow: Convert image to rgb if grayscale

我有一个 rgb 和灰度图像的数据集。在遍历数据集时,我想检测图像是否为灰度图像,以便我可以将其转换为 rgb。我想使用 tf.shape(image) 来检测图像的尺寸。对于 rgb 图像,我得到类似 [1, 100, 100, 3] 的内容。对于灰度图像,函数 returns 例如 [1, 100, 100]。我想使用 len(tf.shape(image)) 来检测它的长度是 4 (=rgb) 还是长度 3 (=grayscale)。那没有用。

这是我目前无法使用的代码:

def process_image(image):
    # Convert numpy array to tensor
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # Take care of grayscale images
    dims = len(tf.shape(image))
    if dims == 3:
        image = np.expand_dims(image, axis=3)
        image = tf.image.grayscale_to_rgb(image)
    return image

是否有其他方法可以将灰度图像转换为 rgb?

你可以使用这样的函数:

import tensorflow as tf

def process_image(image):
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    image_rgb =  tf.cond(tf.rank(image) < 4,
                         lambda: tf.image.grayscale_to_rgb(tf.expand_dims(image, -1)),
                         lambda: tf.identity(image))
    # Add shape information
    s = image.shape
    image_rgb.set_shape(s)
    if s.ndims is not None and s.ndims < 4:
        image_rgb.set_shape(s.concatenate(3))
    return image_rgb

我有一个非常相似的问题,我想一次加载 rgb 和灰度图像。 Tensorflow 支持在读入图像时设置通道号。因此,如果图像具有不同数量的通道,这可能就是您正在寻找的:

# to get greyscale:
tf.io.decode_image(raw_img, expand_animations = False, dtype=tf.float32, channels=1)

# to get rgb:
tf.io.decode_image(raw_img, expand_animations = False, dtype=tf.float32, channels=3)

-> 您甚至可以在同一张图像和内部 tf.data.Dataset 映射中执行这两项操作!

您现在必须设置 channels 变量以匹配您需要的形状,因此所有加载的图像都将是该形状。比你可以无条件重塑。

这还允许您在 Tensorflow 中直接将灰度图像加载到 RGB。举个例子:

    >> a = Image.open(r"Path/to/rgb_img.JPG")
    >> np.array(a).shape
    (666, 1050, 3)
    >> a = a.convert('L')
    >> np.array(a).shape
    (666, 1050)
    >> b = np.array(a)
    >> im = Image.fromarray(b) 
    >> im.save(r"Path/to/now_it_is_greyscale.jpg")
    >> raw_img = tf.io.read_file(r"Path/to/now_it_is_greyscale.jpg")
    >> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=3)
    >> img.shape
    TensorShape([666, 1050, 3])
    >> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=1)
    >> img.shape
    TensorShape([666, 1050, 1])

得到ValueError: 'images' contains no shape.就用expand_animations = False!参见: