如何在 Tensorflow 中为 segment_ids 张量内每次出现的项目获取唯一 ID

How to get a unique id for every occurrence of an item inside a tensor of segment_ids in Tensorflow

假设 x 包含段 ID,我想为每个段 ID 内的每个项目提供一个唯一的 ID。这需要在tensorflow操作

中执行
x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])

需要输出:

[0, 1, 0, 1, 0, 1, 0, 2]

只计算每个段 ID 中的每个项目。我不想使用 py_func.

尝试 tf.unique_with_countstf.while_loop:

import tensorflow as tf

x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])

unique, _, count = tf.unique_with_counts(x)

i = tf.constant(0)
result = tf.zeros_like(x)
c = lambda i, result, unique, count: tf.less(i, tf.shape(unique)[0])
b = lambda i, r, u, c: (tf.add(i, 1), tf.tensor_scatter_nd_update(r, tf.where(tf.equal(u[i], x)), tf.range(c[i])), u, c)

_, result, _, _ = tf.while_loop(c, b, loop_vars=[i, result, unique, count])

print(result)
# tf.Tensor([0 1 0 1 0 1 0 2], shape=(8,), dtype=int32)
point_ids = tf.zeros_like(x)
for i in range(n_pillars):
    indicies = tf.cast(tf.where(tf.equal(x, i)), tf.int32)
    updates = tf.range(len(indicies))
    point_ids += tf.scatter_nd(indicies, updates, shape=tf.shape(idx))