使用 Tensorflow 将图像从 RGB 转换为调色板中的索引

Convert an image from RGB to index in palette using Tensorflow

我想将 RGB 图像转换为单通道图像,其值为调色板(已提取)中的整数索引。

一个例子:

import tensorflow as tf

# Example input: a tiny RGB image built from float literals.
# image shape (height=2, width=2, channels=3)
image = tf.constant([
  [
    [1., 1., 1.], [1., 0., 0.]
  ],
  [
    [0., 0., 1.], [1., 0., 0.]
  ]
])

# palette is a tensor with the extracted colors, one color per row;
# the row number of a color is the index expected in the result
# palette shape (num_colors_in_palette, 3)
palette = tf.constant([
  [1., 0., 0.],
  [0., 0., 1.],
  [1., 1., 1.]
])

# rgb_to_indexed is the function being asked for; it is not defined in this snippet
indexed_image = rgb_to_indexed(image, palette)
# desired result: [[2, 0], [1, 0]]
# result shape (height, width)

我可以想出几种在纯 python 中实现 rgb_to_indexed(image, palette) 的方法,但我很难找到如何以 Tensorflow 的方式实现它(使用 @tf.function 以启用 AutoGraph,并避免 for 循环),即仅(或主要)使用矢量化操作。

编辑 1:显示示例 python/numpy 代码

如果代码不需要使用 Tensorflow,非矢量化实现可以是:

import numpy as np

def rgb_to_indexed(image, palette):
    """Map every pixel of an image to the index of its color in a palette.

    Non-vectorized reference implementation: it visits each pixel in a
    Python loop and compares it against every palette row.

    Args:
        image: NumPy array of shape (height, width, channels).
        palette: NumPy array of shape (num_colors, channels).

    Returns:
        An int64 array of shape (height, width) whose entries are row
        indices into ``palette``.

    Raises:
        ValueError: If a pixel's color does not appear in the palette.
    """
    # Indices are integers, so allocate an integer buffer; every cell is
    # written below (or an exception is raised), so no fill value is needed.
    result = np.empty((image.shape[0], image.shape[1]), dtype=np.int64)

    for i, row in enumerate(image):
        for j, color in enumerate(row):
            # Rows of the palette whose every channel equals this pixel.
            matches = np.where(np.all(palette == color, axis=1))[0]
            if matches.size == 0:
                raise ValueError(
                    f"color {color.tolist()} at pixel ({i}, {j}) "
                    "is not in the palette"
                )
            # If the palette contains duplicate colors, the first one wins.
            result[i, j] = matches[0]
    return result

# presumably `image` and `palette` are the tensors from the earlier example;
# .numpy() is called on each so the NumPy implementation receives arrays
indexed_image = rgb_to_indexed(image.numpy(), palette.numpy())
# indexed_image is [[2, 0], [1, 0]]

我使用了另一个问题中描述的技术,并将其从 Numpy 改编为 Tensorflow。它是完全矢量化的,执行速度非常快。

首先在 Numpy 中(矢量化):

def rgb_to_indexed(image, palette):
    """Convert an image to palette indices using vectorized NumPy operations.

    Args:
        image: NumPy array of shape (height, width, channels).
        palette: NumPy array of shape (num_colors, channels).

    Returns:
        An int64 array of shape (height, width). Each entry is the row index
        in ``palette`` of the pixel's color, or -42 when the color is not
        present in the palette.
    """
    # Sentinel stored for pixels whose color is not in the palette.
    INDEX_OF_COLOR_NOT_FOUND = -42

    height, width = image.shape[0], image.shape[1]

    # flattens the image to the shape (height*width, channels)
    flattened_image = image.reshape(height * width, -1)
    num_pixels = flattened_image.shape[0]

    # palette[:, None] has shape (num_colors, 1, channels); broadcasting it
    # against the flattened image compares every color with every pixel.
    # A (color, pixel) pair matches only if all channels are equal.
    matches = np.all(flattened_image == palette[:, None], axis=2)
    color_indices, pixel_indices = np.where(matches)

    # starts from the "not found" sentinel, then overwrites the pixels for
    # which a palette color has been found
    indexed_image = np.full(num_pixels, INDEX_OF_COLOR_NOT_FOUND, dtype="int64")
    indexed_image[pixel_indices] = color_indices

    # reshapes back to a single-channel (height, width) image of indices
    return indexed_image.reshape(height, width)

然后是我将其改写为 Tensorflow 的版本:

@tf.function
def rgba_to_indexed(image, palette):
    """Convert an image tensor to palette indices using TensorFlow ops only.

    Despite the name, nothing here is specific to four channels: all
    dimensions after the first two are flattened and compared as a whole
    against the palette rows.

    Args:
        image: tensor whose first two dimensions are (height, width),
            followed by the channels.
        palette: tensor of shape (num_colors, channels), with the same dtype
            as ``image`` (``tf.equal`` requires matching dtypes).

    Returns:
        An int32 tensor of shape (height, width). Each entry is the row index
        in ``palette`` of the pixel's color, or -42 when the color is not
        present in the palette.
    """
    # Sentinel stored for pixels whose color is not in the palette.
    INDEX_OF_COLOR_NOT_FOUND = -42

    image_shape = tf.shape(image)
    height, width = image_shape[0], image_shape[1]

    # flattens the image to have (height*width, channels)
    # so it has the same rank as the palette
    flattened_image = tf.reshape(image, [height * width, -1])
    num_pixels = tf.shape(flattened_image)[0]

    # palette[:, None] has shape (num_colors, 1, channels); broadcasting it
    # against the flattened image compares every color with every pixel.
    # A (color, pixel) pair matches only if all channels are equal.
    matches = tf.reduce_all(tf.equal(flattened_image, palette[:, None]), axis=2)

    # tf.where yields one (color_index, pixel_index) row per match; cast to
    # int32 so the values share the dtype of the tensor filled below.
    match_positions = tf.cast(tf.where(matches), "int32")
    color_indices = match_positions[:, 0]
    pixel_indices = tf.expand_dims(match_positions[:, 1], -1)

    # fills with the default value, then overwrites the entries of the
    # pixels whose color is present in the palette. A scatter *update* is
    # used (rather than an add with an offset) so a pixel matching duplicate
    # palette rows still receives a valid index instead of a sum of indices.
    indexed_image = tf.fill([num_pixels], INDEX_OF_COLOR_NOT_FOUND)
    indexed_image = tf.tensor_scatter_nd_update(
        indexed_image,
        pixel_indices,
        color_indices)

    # reshapes the image back to (height, width)
    return tf.reshape(indexed_image, [height, width])