将 1 通道掩码应用于张量流中的 3 通道张量
Apply 1 channel mask to 3 channel Tensor in tensorflow
我正在尝试将掩码(二进制,只有一个通道)应用于 RGB 图像(3 个通道,归一化为 [0, 1])。我目前的解决方案是,我将 RGB 图像拆分成它的通道,将它与蒙版相乘并再次连接这些通道:
with tf.variable_scope('apply_mask') as scope:
# Output mask is in range [-1, 1], bring to range [0, 1] first
zero_one_mask = (output_mask + 1) / 2
# Apply mask to all channels.
channels = tf.split(3, 3, output_img)
channels = [tf.mul(c, zero_one_mask) for c in channels]
output_img = tf.concat(3, channels)
但是,这似乎效率很低,特别是因为据我所知,none 这些计算是就地完成的。有没有更有效的方法来做到这一点?
tf.mul()
operator supports numpy-style broadcasting,可以让您稍微简化和优化代码。
假设 zero_one_mask
是一个 m x n
张量,output_img
是一个 b x m x n x 3
(其中 b
是批量大小 - 我是从您在维度 3)* 上拆分 output_img
的事实推断这一点。您可以使用 tf.expand_dims()
将 zero_one_mask
广播到 channels
,方法是将其重塑为 m x n x 1
张量:
with tf.variable_scope('apply_mask') as scope:
# Output mask is in range [-1, 1], bring to range [0, 1] first
# NOTE: Assumes `output_mask` is a 2-D `m x n` tensor.
zero_one_mask = tf.expand_dims((output_mask + 1) / 2, 2)
# Apply mask to all channels.
# NOTE: Assumes `output_img` is a 4-D `b x m x n x c` tensor.
output_img = tf.mul(output_img, zero_one_mask)
(* 如果 output_img
是 4-D b x m x n x c
(对于任意数量的通道 c
)或 3-D m x n x c
张量,这将同样有效,由于广播的工作方式。)
我正在尝试将掩码(二进制,只有一个通道)应用于 RGB 图像(3 个通道,归一化为 [0, 1])。我目前的解决方案是,我将 RGB 图像拆分成它的通道,将它与蒙版相乘并再次连接这些通道:
with tf.variable_scope('apply_mask') as scope:
# Output mask is in range [-1, 1], bring to range [0, 1] first
zero_one_mask = (output_mask + 1) / 2
# Apply mask to all channels.
channels = tf.split(3, 3, output_img)
channels = [tf.mul(c, zero_one_mask) for c in channels]
output_img = tf.concat(3, channels)
但是,这似乎效率很低,特别是因为据我所知,none 这些计算是就地完成的。有没有更有效的方法来做到这一点?
tf.mul()
operator supports numpy-style broadcasting,可以让您稍微简化和优化代码。
假设 zero_one_mask
是一个 m x n
张量,output_img
是一个 b x m x n x 3
(其中 b
是批量大小 - 我是从您在维度 3)* 上拆分 output_img
的事实推断这一点。您可以使用 tf.expand_dims()
将 zero_one_mask
广播到 channels
,方法是将其重塑为 m x n x 1
张量:
with tf.variable_scope('apply_mask') as scope:
# Output mask is in range [-1, 1], bring to range [0, 1] first
# NOTE: Assumes `output_mask` is a 2-D `m x n` tensor.
zero_one_mask = tf.expand_dims((output_mask + 1) / 2, 2)
# Apply mask to all channels.
# NOTE: Assumes `output_img` is a 4-D `b x m x n x c` tensor.
output_img = tf.mul(output_img, zero_one_mask)
(* 如果 output_img
是 4-D b x m x n x c
(对于任意数量的通道 c
)或 3-D m x n x c
张量,这将同样有效,由于广播的工作方式。)