Tensorflow 根据索引从每一行中删除元素(gather_nd 的否定)

Tensorflow delete element from each row based on index (negation of gather_nd)

首先,一些背景。我目前正在为我的数据输入管道编写自定义 TensorFlow 2.x 预处理函数。最后我会 map 批处理。本质上,该函数接收一批行并通过复制行并根据条件删除每行中的一个元素来生成 更大的 批次。例如,如果输入批次看起来像

[[4,  1, 10, 10,  2],
 [10, 7,  9, 10, 10],
 [6,  8, 10,  3,  5]]

然后该函数应根据没有 10 的位置生成新样本。每次出现非 10 时都会删除这些元素,例如从第一个样本(新样本)中移除 4 个,移除 1 个(另一个新样本),...,从最后一个样本中移除 5 个。从输入批次中,我们将有 9 个样本:

[[1, 10, 10, 2],
 [4, 10, 10, 2],
 [4, 1, 10, 10],
 [10, 9, 10, 10],
 [10, 7, 10, 10],
 [8, 10, 3, 5],
 [6, 10, 3, 5],
 [6, 8, 10, 5],
 [6, 8, 10, 3]]

现在开始我的活动。通过使用 tf.wheretf.gathertf.unique_with_countstf.repeat,我能够将原始行复制正确的次数:

def myFunction(data):
    # Returns a 2-column tensor, with each row
    # being the index pair...
    presentIndices = tf.where(data != 10)
    # Grab the 1st column (rows) and count how many
    # times each row appears...
    rows = tf.gather(presentIndices, indices=0, axis=1)
    _, _, counts = tf.unique_with_counts(rows)
    # Repeat each row according to counts...
    data = tf.repeat(data, repeats=counts, axis=0)
    # data now has 1st row copied 3 times, 2nd row copied twice, etc.

但是,鉴于我在 presentIndices 中有索引,我现在不知道如何从每一行中删除适当的元素。使用 numpy,我可以简单地索引 data 并相应地重新整形,但似乎 TensorFlow 没有很好的索引到多维张量的能力。

我已经研究了 tf.boolean_mask,但是我需要再次在适当的位置分配 False。我能找到的最接近的东西是 tf.gather_nd,但是 提取 给定索引的数据。相反,我本质上需要否定该功能。给定索引,提取所有数据 except at these indices.

有没有办法利用现有的 TensorFlow 函数来获得我想要的功能?

谢谢!

您可以执行以下操作。我知道这可能有点head-spinning。最简单的方法就是使用此代码作为参考来做一个示例。

def f(data):
    
    # Boolean mask where it's not 10
    a = (data != 10)
    # Repeat and reshape to n x 5 x 5
    a = tf.reshape(tf.repeat(a, 5), [-1, 5, 5])
    # Create a identity matrix of size 1 x 5 x 5
    eye = tf.reshape(tf.eye(5), [1,5,5])
    # Create a mask of size n x 5 x 5. This basically forces a to have only a single false value for each row
    # This single false element is the element to be removed
    mask = ~tf.cast(tf.reshape(tf.cast(a,'int32')* tf.cast(eye, 'int32'), [-1, 5]), 'bool')

    # Remove all the rows with all elements True. This ensures at least one element is removed from all existing rows
    mask = tf.cast(mask, 'int32') * tf.cast(~tf.reduce_all(mask, axis=1, keepdims=True), 'int32')
    mask = tf.cast(mask, 'bool')
    
    # Get the required rows and discard others and reshape
    res = tf.boolean_mask(tf.repeat(data, 5, axis=0), mask)     
    res = tf.reshape(res, [-1,4])

    return res

这会产生,

tf.Tensor(
[[ 1 10 10  2]
 [ 4 10 10  2]
 [ 4  1 10 10]
 [10  9 10 10]
 [10  7 10 10]
 [ 8 10  3  5]
 [ 6 10  3  5]
 [ 6  8 10  5]
 [ 6  8 10  3]], shape=(9, 4), dtype=int32)

您可以使用 tf.boolean_masktf.scatter_nd 为您的(重复)数据创建一个布尔向量. 首先,您创建一个索引张量来指示要屏蔽的值:

row = tf.constant([0,1,2,3,4,5,6,7,8] ,dtype = tf.int64)
mask_for_each_row = tf.stack([row ,presentIndices[: , 1]],axis = 1 )

然后在 tf.scatter_nd 方法中使用 mask_for_each_row 作为索引:

samples =tf.boolean_mask(data ,~tf.scatter_nd(mask_for_each_row , 
            tf.ones((9,),dtype = tf.bool),(9,5)))
samples = tf.reshape(samples ,(9,4))

样本张量:

      <tf.Tensor: shape=(9, 4), dtype=int32, numpy=
      array([[ 1, 10, 10,  2],
             [ 4, 10, 10,  2],
             [ 4,  1, 10, 10],
             [10,  9, 10, 10],
             [10,  7, 10, 10],
             [ 8, 10,  3,  5],
             [ 6, 10,  3,  5],
             [ 6,  8, 10,  5],
             [ 6,  8, 10,  3]])>