tensorflow 断言元素 tf.where

tensorflow assert element tf.where

我有 2 个矩阵,形状为:

pob.shape = (2,49,20)  
rob.shape = np.zeros((2,49,20))  

并且我想获取值为 !=0 的 pob 元素的索引。所以在 numpy 中我可以这样做:

x,y,z = np.where(pob!=0)

例如:

x = [2,4,7]  
y = [3,5,5]  
z = [3,5,6]

我想更改 rob 的值:

rob[x1,y1,:] = np.ones((20))

我如何使用 tensorflow 对象执行此操作? 我尝试使用 tf.where 但我无法从张量 obj

中获取索引值

你可以使用tf.range()tf.meshgrid()创建索引矩阵,然后使用tf.where()和你的条件来获得满足它的索引。然而,接下来会出现棘手的部分:您无法根据 TF (my_tensor[my_indices] = my_values).

中的索引轻松地为张量赋值。

您的问题 ("for all (i,j,k), if pob[i,j,k] != 0 then rob[i,j] = 1") 的解决方法如下:

import tensorflow as tf

# Example values for demonstration:
pob_val = [[[0, 0, 0], [1, 0, 0], [1, 0, 1]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]]]
pob = tf.constant(pob_val)
pob_shape = tf.shape(pob)
rob = tf.zeros(pob_shape)

# Get the mask:
mask = tf.cast(tf.not_equal(pob, 0), tf.uint8)

# If there's at least one "True" in mask[i, j, :], make all mask[i, j, :] = True:
mask = tf.cast(tf.reduce_max(mask, axis=-1, keepdims=True), tf.bool)
mask = tf.tile(mask, [1, 1, pob_shape[-1]])

# Apply mask:
rob = tf.where(mask, tf.ones(pob_shape), rob)

with tf.Session() as sess:
    rob_eval = sess.run(rob)
    print(rob_eval)
    # [[[0. 0. 0.]
    #   [1. 1. 1.]
    #   [1. 1. 1.]]
    #
    #  [[1. 1. 1.]
    #   [0. 0. 0.]
    #   [0. 0. 0.]]]