如何在tensorflow中从不同的对象类采样n个像素?
How to sample n pixels from different object classes in tensorflow?
问题
我想从图像中的每个实例 class 中随机采样 n
像素。
假设我的图像是 I
,宽度 w
和高度 h
。我还有一张带有标签 L
的图像,描述了与 I
形状相同的实例 classes。
当前方法
我目前的想法是首先将标签重塑为一个大的形状向量 (N_p, 1)
。然后我将它们重复 N_c
次以获得形状 (N_p, N_c)
。现在我重复一个向量 l
,它由所有形状为 (1, N_c)
到形状 (N_p, N_c)
的唯一标签组成。使这两个相等得到一个矩阵,其中一个在列 y
和行 x
中,其中对应于行 x
的像素属于 class 对应于列 y
.
下一步是将索引位置递增的矩阵与前一个矩阵连接起来。现在我可以在各行中随机洗牌该矩阵。
唯一缺少的步骤是提取该矩阵的 n*N_c
行,首先每个 class 都有一个。然后使用矩阵右侧的索引,我可以使用
tf.gather_nd
从原始图像中获取像素 I
。
问题
tensorflow中缺失的操作如何实现?即:获取 k*n 行,使得它们包含前 n 行,矩阵的每一列在矩阵的左侧部分都有一个行。
这些操作效率高吗?
有没有更简单的方法?
解决方案
对于任何感兴趣的人,这里是我的问题的解决方案以及相应的 tensorflow 代码。我在正确的轨道上,缺少的功能是
tf.nn.top_k
下面是一些示例代码,用于从图像的每个实例中采样 k 个像素 类。
import tensorflow as tf
seed = 42
width = 10
height = 6
embedding_dim = 3
sample_size = 2
image = tf.random_normal([height, width, embedding_dim], mean=0, stddev=4, seed=seed)
labels = tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
[0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.uint8)
labels = tf.cast(labels, tf.int32)
# First reshape to one vector
image_v = tf.reshape(image, [-1, embedding_dim])
labels_v = tf.reshape(labels, [-1])
# Get classes
classes, indices = tf.unique(labels_v)
# Dimensions
N_c = tf.shape(classes)[0]
N_p = tf.shape(labels_v)[0]
# Helper matrices
I = tf.tile(tf.expand_dims(indices, [-1]), [1, N_c])
C = tf.tile(tf.transpose(tf.expand_dims(tf.range(N_c), [-1])), [N_p, 1])
E = tf.cast(tf.equal(I, C), tf.int32)
P = tf.expand_dims(tf.range(N_p) + 1, [-1])
R = tf.concat([E, P], axis=1)
R_rand = tf.random_shuffle(R, seed = seed)
E_rand, P_rand = tf.split(R_rand, [N_c, 1], axis = 1)
M = tf.transpose(E_rand)
_, topInidices = tf.nn.top_k(M, k = sample_size)
topInidicesFlat = tf.expand_dims(tf.reshape(topInidices, [-1]), [-1])
sampleIndices = tf.gather_nd(P_rand, topInidicesFlat)
samples = tf.gather_nd(image_v, sampleIndices)
sess = tf.Session()
list = [image,
labels,
image_v,
labels_v,
classes,
indices,
N_c,
N_p,
I,
C,
E,
P,
R,
R_rand,
E_rand,
P_rand,
M,
topInidices,
topInidicesFlat,
sampleIndices,
samples
]
list_ = sess.run(list)
print(list_)
问题
我想从图像中的每个实例 class 中随机采样 n
像素。
假设我的图像是 I
,宽度 w
和高度 h
。我还有一张带有标签 L
的图像,描述了与 I
形状相同的实例 classes。
当前方法
我目前的想法是首先将标签重塑为一个大的形状向量 (N_p, 1)
。然后我将它们重复 N_c
次以获得形状 (N_p, N_c)
。现在我重复一个向量 l
,它由所有形状为 (1, N_c)
到形状 (N_p, N_c)
的唯一标签组成。使这两个相等得到一个矩阵,其中一个在列 y
和行 x
中,其中对应于行 x
的像素属于 class 对应于列 y
.
下一步是将索引位置递增的矩阵与前一个矩阵连接起来。现在我可以在各行中随机洗牌该矩阵。
唯一缺少的步骤是提取该矩阵的 n*N_c
行,首先每个 class 都有一个。然后使用矩阵右侧的索引,我可以使用
tf.gather_nd
从原始图像中获取像素 I
。
问题
tensorflow中缺失的操作如何实现?即:获取 k*n 行,使得它们包含前 n 行,矩阵的每一列在矩阵的左侧部分都有一个行。
这些操作效率高吗?
有没有更简单的方法?
解决方案
对于任何感兴趣的人,这里是我的问题的解决方案以及相应的 tensorflow 代码。我在正确的轨道上,缺少的功能是
tf.nn.top_k
下面是一些示例代码,用于从图像的每个实例中采样 k 个像素 类。
import tensorflow as tf
seed = 42
width = 10
height = 6
embedding_dim = 3
sample_size = 2
image = tf.random_normal([height, width, embedding_dim], mean=0, stddev=4, seed=seed)
labels = tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
[0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.uint8)
labels = tf.cast(labels, tf.int32)
# First reshape to one vector
image_v = tf.reshape(image, [-1, embedding_dim])
labels_v = tf.reshape(labels, [-1])
# Get classes
classes, indices = tf.unique(labels_v)
# Dimensions
N_c = tf.shape(classes)[0]
N_p = tf.shape(labels_v)[0]
# Helper matrices
I = tf.tile(tf.expand_dims(indices, [-1]), [1, N_c])
C = tf.tile(tf.transpose(tf.expand_dims(tf.range(N_c), [-1])), [N_p, 1])
E = tf.cast(tf.equal(I, C), tf.int32)
P = tf.expand_dims(tf.range(N_p) + 1, [-1])
R = tf.concat([E, P], axis=1)
R_rand = tf.random_shuffle(R, seed = seed)
E_rand, P_rand = tf.split(R_rand, [N_c, 1], axis = 1)
M = tf.transpose(E_rand)
_, topInidices = tf.nn.top_k(M, k = sample_size)
topInidicesFlat = tf.expand_dims(tf.reshape(topInidices, [-1]), [-1])
sampleIndices = tf.gather_nd(P_rand, topInidicesFlat)
samples = tf.gather_nd(image_v, sampleIndices)
sess = tf.Session()
list = [image,
labels,
image_v,
labels_v,
classes,
indices,
N_c,
N_p,
I,
C,
E,
P,
R,
R_rand,
E_rand,
P_rand,
M,
topInidices,
topInidicesFlat,
sampleIndices,
samples
]
list_ = sess.run(list)
print(list_)