How to mark the first occurrence of each element in TensorFlow

For example, given the input:

a = tf.constant([1, 2, 3, 1, 2, 1, 4])

I want this mask (True only where a value appears for the first time):

a_mask = [True, True, True, False, False, False, True]

I'm not sure whether this answers your question, but here is a plain Python loop over the tensor's values:

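import tensorflow as tf

a = tf.constant([1, 2, 3, 1, 2, 1, 4])
l = list(a.numpy())

# Init a set of values seen so far and the mask
se = set()
a_mask = []

for x in l:
    if x in se:
        a_mask.append(False)  # value already appeared earlier
    else:
        a_mask.append(True)   # first occurrence of this value
    se.add(x)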
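Since the loop already converts the tensor to NumPy, a vectorized alternative is np.unique with return_index=True, which returns the position of the first occurrence of each unique value. This is a swapped-in technique, not the answerer's code; the sketch uses a NumPy array directly as a stand-in for a.numpy().

```python
import numpy as np

# Stand-in for a.numpy() from the question
values = np.array([1, 2, 3, 1, 2, 1, 4])

# return_index=True gives the position of the first occurrence of each unique value
_, first_idx = np.unique(values, return_index=True)

# Start from all-False and switch on only the first occurrences
a_mask = np.zeros(values.shape[0], dtype=bool)
a_mask[first_idx] = True

print(a_mask.tolist())  # [True, True, True, False, False, False, True]
```

If you need a tensor again, tf.constant(a_mask) turns the result back into a boolean tensor.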

Try combining tf.math.unsorted_segment_min with tf.tensor_scatter_nd_update:

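import tensorflow as tf

a = tf.constant([1, 2, 3, 1, 2, 1, 4])
# v holds the unique values, i maps each element of a to its index in v
# here v = [1, 2, 3, 4] and i = [0, 1, 2, 0, 1, 0, 3]
v, i = tf.unique(a)
# Smallest position of each unique value, i.e. its first occurrence: [0, 1, 2, 6]
indices = tf.math.unsorted_segment_min(tf.range(tf.shape(a)[0]), i, tf.shape(v)[0])
# Scatter True into an all-False mask at those positions
updates = tf.ones_like(indices, dtype=tf.bool)
mask = tf.tensor_scatter_nd_update(tf.zeros_like(a, dtype=tf.bool), tf.expand_dims(indices, axis=-1), updates)
print(mask)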
tf.Tensor([ True  True  True False False False  True], shape=(7,), dtype=bool)
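A different technique, not taken from either answer, avoids tf.unique and the scatter step entirely: element i is a first occurrence exactly when no earlier position j < i holds the same value. The sketch below builds an n-by-n comparison matrix, so memory grows quadratically with the length of the tensor; it is only sensible for short 1-D tensors. The helper name first_occurrence_mask is hypothetical.

```python
import tensorflow as tf


def first_occurrence_mask(a):
    """Hypothetical helper: True where a[i] does not appear in a[:i] (a is 1-D)."""
    n = tf.shape(a)[0]
    positions = tf.range(n)
    # same[i, j] is True when a[i] == a[j]
    same = tf.equal(tf.expand_dims(a, 1), tf.expand_dims(a, 0))
    # earlier[i, j] is True when j < i
    earlier = tf.less(tf.expand_dims(positions, 0), tf.expand_dims(positions, 1))
    # An element is a repeat if some earlier position holds the same value
    seen_before = tf.reduce_any(tf.logical_and(same, earlier), axis=1)
    return tf.logical_not(seen_before)


mask = first_occurrence_mask(tf.constant([1, 2, 3, 1, 2, 1, 4]))
print(mask.numpy().tolist())  # [True, True, True, False, False, False, True]
```

Because it uses only element-wise and reduction ops, it should also work inside a tf.function, but the segment-min answer scales better for long inputs.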