使用 Tensorflow 将图像从 RGB 转换为调色板中的索引
Convert an image from RGB to index in palette using Tensorflow
我想将 RGB 图像转换为单通道图像,其值为调色板(已提取)中的整数索引。
一个例子:
import tensorflow as tf
# Sample input: a 2x2 RGB image, shape (height=2, width=2, channels=3).
image = tf.constant(
    [
        [[1., 1., 1.], [1., 0., 0.]],  # row 0: white, red
        [[0., 0., 1.], [1., 0., 0.]],  # row 1: blue, red
    ]
)
# The already-extracted palette, one colour per row,
# shape (num_colors_in_palette, 3): 0 = red, 1 = blue, 2 = white.
palette = tf.constant(
    [
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 1., 1.],
    ]
)
indexed_image = rgb_to_indexed(image, palette)
# Wanted: [[2, 0], [1, 0]], i.e. a tensor of shape (height, width).
我可以想出几种用纯 Python 实现 rgb_to_indexed(image, palette) 的方法，
但很难找到以 TensorFlow 的方式实现它的办法（使用 @tf.function 以便 AutoGraph 处理，并避免 for 循环），即仅（或主要）使用矢量化操作。
编辑 1:显示示例 python/numpy 代码
如果代码不需要使用 Tensorflow,非矢量化实现可以是:
import numpy as np
def rgb_to_indexed(image, palette):
    """Loop-based reference: map each pixel of an image to its palette index.

    Args:
        image: array-like of shape (height, width, channels).
        palette: array-like of shape (num_colors, channels), one colour per row.

    Returns:
        int64 array of shape (height, width) whose entry (i, j) is the index of
        the first palette row equal to ``image[i, j]`` in every channel.

    Raises:
        ValueError: if some pixel colour does not occur in the palette.
    """
    image = np.asarray(image)
    palette = np.asarray(palette)
    # Indices are integers; also, every slot is written below, so np.empty is safe.
    result = np.empty((image.shape[0], image.shape[1]), dtype=np.int64)
    for i, row in enumerate(image):
        for j, color in enumerate(row):
            matches = np.flatnonzero(np.all(palette == color, axis=1))
            if matches.size == 0:
                raise ValueError(
                    f"colour {color.tolist()} at pixel ({i}, {j}) is not in the palette"
                )
            # Take the first hit: duplicated palette colours would otherwise
            # produce a multi-element array that cannot go into one slot.
            result[i, j] = matches[0]
    return result
# Feed NumPy copies of the example tensors to the loop-based version.
indexed_image = rgb_to_indexed(image=image.numpy(), palette=palette.numpy())
# expected value: [[2, 0], [1, 0]]
我使用了另一个问题（原文链接已丢失）中描述的技术，并将其从 NumPy 改编为 TensorFlow。它是完全矢量化的，执行速度非常快。
首先在 Numpy 中(矢量化):
def rgb_to_indexed(image, palette, not_found=-1):
    """Vectorised NumPy mapping of every pixel colour to its palette index.

    Args:
        image: array-like of shape (height, width, channels); any number of
            leading axes is accepted, the last axis holds the colour channels.
        palette: array-like of shape (num_colors, channels), one colour per row.
        not_found: value written for pixels whose colour is not in the palette.
            The default of -1 matches the previous hard-coded fill value.

    Returns:
        int64 array of shape ``image.shape[:-1]`` holding, for each pixel, the
        index of the first palette row equal to it in every channel, or
        ``not_found`` when no row matches.
    """
    image = np.asarray(image)
    palette = np.asarray(palette)
    # Flatten to (num_pixels, channels). Putting -1 on the pixel axis (not the
    # channel axis) keeps the reshape well-defined for empty images.
    pixels = image.reshape(-1, image.shape[-1])
    # matches[c, p] is True when pixel p equals palette colour c in all channels.
    matches = (pixels[np.newaxis, :, :] == palette[:, np.newaxis, :]).all(axis=-1)
    indexed_image = np.full(pixels.shape[0], not_found, dtype=np.int64)
    if matches.shape[0] > 0:  # argmax is undefined over an empty palette
        found = matches.any(axis=0)
        # argmax returns the first True, i.e. the lowest palette index, so
        # duplicated palette colours resolve deterministically; assigning
        # through repeated fancy indices gives no such guarantee.
        indexed_image[found] = matches.argmax(axis=0)[found]
    # "Un-flatten" back to the image's spatial shape, e.g. (height, width).
    return indexed_image.reshape(image.shape[:-1])
然后是我翻译成 TensorFlow 的版本：
@tf.function
def rgba_to_indexed(image, palette, not_found=-42):
    """Vectorised TensorFlow mapping of every pixel colour to its palette index.

    Args:
        image: tensor of shape (height, width, channels); any number of leading
            axes is accepted, the last axis holds the colour channels. Must have
            the same dtype as ``palette``.
        palette: tensor of shape (num_colors, channels), one colour per row.
        not_found: value written for pixels whose colour is not in the palette.
            The default of -42 matches the previous hard-coded value.

    Returns:
        int32 tensor of shape ``image.shape[:-1]`` holding, for each pixel, the
        index of the first palette row equal to it in every channel, or
        ``not_found`` when no row matches.
    """
    num_colors = tf.shape(palette)[0]
    # Flatten to (num_pixels, channels) so it has the same rank as the palette.
    pixels = tf.reshape(image, [-1, tf.shape(image)[-1]])
    num_pixels = tf.shape(pixels)[0]
    # matches[c, p] is True when pixel p equals palette colour c in all channels.
    matches = tf.reduce_all(
        tf.equal(pixels[tf.newaxis, :, :], palette[:, tf.newaxis, :]), axis=-1
    )
    # Matching palette rows keep their own index; the others get num_colors,
    # which is larger than any valid index and therefore means "no match".
    color_ids = tf.range(num_colors)[:, tf.newaxis]
    candidates = tf.where(matches, color_ids, num_colors)
    # Sentinel row keeps the reduction well-defined for an empty palette.
    candidates = tf.concat(
        [candidates, tf.fill([1, num_pixels], num_colors)], axis=0
    )
    # The minimum is the lowest matching palette index. This is deterministic
    # even for duplicated palette colours, whose indices a scatter-add would sum.
    first_match = tf.reduce_min(candidates, axis=0)
    indexed_image = tf.where(
        first_match < num_colors, first_match, tf.cast(not_found, tf.int32)
    )
    # Reshape back to the image's spatial shape, e.g. (height, width).
    return tf.reshape(indexed_image, tf.shape(image)[:-1])
我想将 RGB 图像转换为单通道图像,其值为调色板(已提取)中的整数索引。
一个例子:
import tensorflow as tf
# Sample input: a 2x2 RGB image, shape (height=2, width=2, channels=3).
image = tf.constant(
    [
        [[1., 1., 1.], [1., 0., 0.]],  # row 0: white, red
        [[0., 0., 1.], [1., 0., 0.]],  # row 1: blue, red
    ]
)
# The already-extracted palette, one colour per row,
# shape (num_colors_in_palette, 3): 0 = red, 1 = blue, 2 = white.
palette = tf.constant(
    [
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 1., 1.],
    ]
)
indexed_image = rgb_to_indexed(image, palette)
# Wanted: [[2, 0], [1, 0]], i.e. a tensor of shape (height, width).
我可以想出几种用纯 Python 实现 rgb_to_indexed(image, palette) 的方法，
但很难找到以 TensorFlow 的方式实现它的办法（使用 @tf.function 以便 AutoGraph 处理，并避免 for 循环），即仅（或主要）使用矢量化操作。
编辑 1:显示示例 python/numpy 代码
如果代码不需要使用 Tensorflow,非矢量化实现可以是:
import numpy as np
def rgb_to_indexed(image, palette):
    """Loop-based reference: map each pixel of an image to its palette index.

    Args:
        image: array-like of shape (height, width, channels).
        palette: array-like of shape (num_colors, channels), one colour per row.

    Returns:
        int64 array of shape (height, width) whose entry (i, j) is the index of
        the first palette row equal to ``image[i, j]`` in every channel.

    Raises:
        ValueError: if some pixel colour does not occur in the palette.
    """
    image = np.asarray(image)
    palette = np.asarray(palette)
    # Indices are integers; also, every slot is written below, so np.empty is safe.
    result = np.empty((image.shape[0], image.shape[1]), dtype=np.int64)
    for i, row in enumerate(image):
        for j, color in enumerate(row):
            matches = np.flatnonzero(np.all(palette == color, axis=1))
            if matches.size == 0:
                raise ValueError(
                    f"colour {color.tolist()} at pixel ({i}, {j}) is not in the palette"
                )
            # Take the first hit: duplicated palette colours would otherwise
            # produce a multi-element array that cannot go into one slot.
            result[i, j] = matches[0]
    return result
# Feed NumPy copies of the example tensors to the loop-based version.
indexed_image = rgb_to_indexed(image=image.numpy(), palette=palette.numpy())
# expected value: [[2, 0], [1, 0]]
我使用了另一个问题（原文链接已丢失）中描述的技术，并将其从 NumPy 改编为 TensorFlow。它是完全矢量化的，执行速度非常快。
首先在 Numpy 中(矢量化):
def rgb_to_indexed(image, palette, not_found=-1):
    """Vectorised NumPy mapping of every pixel colour to its palette index.

    Args:
        image: array-like of shape (height, width, channels); any number of
            leading axes is accepted, the last axis holds the colour channels.
        palette: array-like of shape (num_colors, channels), one colour per row.
        not_found: value written for pixels whose colour is not in the palette.
            The default of -1 matches the previous hard-coded fill value.

    Returns:
        int64 array of shape ``image.shape[:-1]`` holding, for each pixel, the
        index of the first palette row equal to it in every channel, or
        ``not_found`` when no row matches.
    """
    image = np.asarray(image)
    palette = np.asarray(palette)
    # Flatten to (num_pixels, channels). Putting -1 on the pixel axis (not the
    # channel axis) keeps the reshape well-defined for empty images.
    pixels = image.reshape(-1, image.shape[-1])
    # matches[c, p] is True when pixel p equals palette colour c in all channels.
    matches = (pixels[np.newaxis, :, :] == palette[:, np.newaxis, :]).all(axis=-1)
    indexed_image = np.full(pixels.shape[0], not_found, dtype=np.int64)
    if matches.shape[0] > 0:  # argmax is undefined over an empty palette
        found = matches.any(axis=0)
        # argmax returns the first True, i.e. the lowest palette index, so
        # duplicated palette colours resolve deterministically; assigning
        # through repeated fancy indices gives no such guarantee.
        indexed_image[found] = matches.argmax(axis=0)[found]
    # "Un-flatten" back to the image's spatial shape, e.g. (height, width).
    return indexed_image.reshape(image.shape[:-1])
然后是我翻译成 TensorFlow 的版本：
@tf.function
def rgba_to_indexed(image, palette, not_found=-42):
    """Vectorised TensorFlow mapping of every pixel colour to its palette index.

    Args:
        image: tensor of shape (height, width, channels); any number of leading
            axes is accepted, the last axis holds the colour channels. Must have
            the same dtype as ``palette``.
        palette: tensor of shape (num_colors, channels), one colour per row.
        not_found: value written for pixels whose colour is not in the palette.
            The default of -42 matches the previous hard-coded value.

    Returns:
        int32 tensor of shape ``image.shape[:-1]`` holding, for each pixel, the
        index of the first palette row equal to it in every channel, or
        ``not_found`` when no row matches.
    """
    num_colors = tf.shape(palette)[0]
    # Flatten to (num_pixels, channels) so it has the same rank as the palette.
    pixels = tf.reshape(image, [-1, tf.shape(image)[-1]])
    num_pixels = tf.shape(pixels)[0]
    # matches[c, p] is True when pixel p equals palette colour c in all channels.
    matches = tf.reduce_all(
        tf.equal(pixels[tf.newaxis, :, :], palette[:, tf.newaxis, :]), axis=-1
    )
    # Matching palette rows keep their own index; the others get num_colors,
    # which is larger than any valid index and therefore means "no match".
    color_ids = tf.range(num_colors)[:, tf.newaxis]
    candidates = tf.where(matches, color_ids, num_colors)
    # Sentinel row keeps the reduction well-defined for an empty palette.
    candidates = tf.concat(
        [candidates, tf.fill([1, num_pixels], num_colors)], axis=0
    )
    # The minimum is the lowest matching palette index. This is deterministic
    # even for duplicated palette colours, whose indices a scatter-add would sum.
    first_match = tf.reduce_min(candidates, axis=0)
    indexed_image = tf.where(
        first_match < num_colors, first_match, tf.cast(not_found, tf.int32)
    )
    # Reshape back to the image's spatial shape, e.g. (height, width).
    return tf.reshape(indexed_image, tf.shape(image)[:-1])