How to combine tensor elements with similar values of specific column?

We are given this 3D input_tensor, which is a tensor representing (batch_size, N, 2).

I want to add up the scores (the second-column elements) wherever the labels (the first-column elements) are the same within each batch. For example, given this tensor with 3 batches, each with 4 predictions of 2 elements, I want required_output_tensor as the result.

Condition: no for loops and no tf.map_fn() in the answer. The reason is that tf.map_fn() runs slowly on GPUs with TF 2.x.

input_tensor = tf.constant([
    [
        [2., 0.7],
        [1., 0.1],
        [3., 0.4],
        [2., 0.8],
    ],
    [
        [2., 0.7],
        [1., 0.1],
        [1., 0.4],
        [4., 0.8],
    ],
    [
        [3., 0.7],
        [1., 0.1],
        [3., 0.4],
        [4., 0.8],
    ]
])

required_output_tensor = [
    [
        [2., 1.5],
        [1., 0.1],
        [3., 0.4],
    ],
    [
        [2., 0.7],
        [1., 0.5],
        [4., 0.8],
    ],
    [
        [3., 1.1],
        [1., 0.1],
        [4., 0.8],
    ]
]

Edit: I can see how we would end up with a ragged tensor. In that case, I am fine with picking the top k elements from each batch, with k = min(size(smallest_batch)), or it can be hard-coded to topk=2.
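For reference, here is a minimal sketch of that top-k idea, assuming the scores have already been summed per label into a dense (batch_size, num_groups) tensor; the aggregated_scores values below are just the summed scores from the example above, used for illustration:

import tensorflow as tf

# Aggregated (label-summed) scores per batch, dense shape (batch_size, num_groups)
aggregated_scores = tf.constant([[1.5, 0.1, 0.4],
                                 [0.7, 0.5, 0.8],
                                 [1.1, 0.1, 0.8]])
topk = 2  # hard-coded, as mentioned above
top_scores, top_idx = tf.math.top_k(aggregated_scores, k=topk)
# top_idx can then be passed to tf.gather(..., batch_dims=1) on a
# (batch_size, num_groups, 2) label/score tensor to keep only the top-k pairs.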

Edit 2: Adding an additional input to try the suggested solutions on:

additional_input_tensor = tf.constant([
    [
        [2., 0.5],
        [1., 0.1],
        [3., 0.4],
        [2., 0.5],
    ],
    [
        [22., 0.7],
        [11., 0.2],
        [11., 0.3],
        [44., 0.8],
    ],
    [
        [3333., 0.7],
        [1111., 0.1],
        [4444., 0.4],
        [5555., 0.8],
    ],
    [
        [2., 0.9],
        [1., 0.2],
        [5., 0.3],
        [2., 0.9],
    ]
])

This problem is generally not well defined, because the input groups may contain different numbers of non-repeated id values, so the result would not be a dense tensor. You could try using ragged tensors, although that can be limiting. One option is to produce a result where every group in the output contains every id, and the scores of the ids that do not appear in the corresponding input group are simply set to zero. You could do it like this:

import tensorflow as tf

input_tensor = tf.constant([
    [
        [2., 0.7],
        [1., 0.1],
        [3., 0.4],
        [2., 0.8],
    ],
    [
        [2., 0.7],
        [1., 0.1],
        [1., 0.4],
        [4., 0.8],
    ],
    [
        [3., 0.7],
        [1., 0.1],
        [3., 0.4],
        [4., 0.8],
    ]
])
# Take input tensor shape
s = tf.shape(input_tensor)
# Flatten first dimensions
flat = tf.reshape(input_tensor, (-1, 2))
# Find unique id values
group_ids, group_idx = tf.unique(flat[:, 0], out_idx=s.dtype)
# Shift id indices per group in the input
num_groups = tf.reduce_max(group_idx) + 1
group_shift = tf.tile(tf.expand_dims(num_groups * tf.range(s[0]), 1), (1, s[1]))
group_idx_shift = group_idx + tf.reshape(group_shift, (-1,))
# Aggregate per group in the input
num_groups_shift = num_groups * s[0]
# Either use unsorted_segment_sum
group_sum = tf.math.unsorted_segment_sum(flat[:, 1], group_idx_shift, num_groups_shift)
# Or use bincount
group_sum = tf.math.bincount(group_idx_shift, weights=flat[:, 1],
                             minlength=num_groups_shift)
# Reshape and concatenate
group_sum_res = tf.reshape(group_sum, (s[0], num_groups))
group_ids_res = tf.tile(tf.expand_dims(group_ids, 0), (s[0], 1))
result = tf.stack([group_ids_res, group_sum_res], axis=-1)
# Sort results
result_s = tf.argsort(group_sum_res, axis=-1, direction='DESCENDING')
result_sorted = tf.gather_nd(result, tf.expand_dims(result_s, axis=-1), batch_dims=1)
print(result_sorted.numpy())
# [[[2.  1.5]
#   [3.  0.4]
#   [1.  0.1]
#   [4.  0. ]]
# 
#  [[4.  0.8]
#   [2.  0.7]
#   [1.  0.5]
#   [3.  0. ]]
# 
#  [[3.  1.1]
#   [4.  0.8]
#   [1.  0.1]
#   [2.  0. ]]]
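
Following up on the top-k idea from the question's first edit: since result_sorted is already sorted by score in descending order within each batch, keeping the top k entries per batch is just a slice of the tensor computed above (topk=2 is the hard-coded value from the question):

# Continues from the snippet above
topk = 2
result_topk = result_sorted[:, :topk]
print(result_topk.numpy())
# [[[2.  1.5]
#   [3.  0.4]]
#
#  [[4.  0.8]
#   [2.  0.7]]
#
#  [[3.  1.1]
#   [4.  0.8]]]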

EDIT:

Here is an alternative using a ragged tensor output:

import tensorflow as tf

input_tensor = tf.constant([...])
# Same as before
s = tf.shape(input_tensor)
flat = tf.reshape(input_tensor, (-1, 2))
group_ids, group_idx = tf.unique(flat[:, 0], out_idx=s.dtype)
num_groups = tf.reduce_max(group_idx) + 1
group_shift = tf.tile(tf.expand_dims(num_groups * tf.range(s[0]), 1), (1, s[1]))
group_idx_shift = group_idx + tf.reshape(group_shift, (-1,))
# Apply unique again to find ids per batch
group_ids2_ref, group_idx2 = tf.unique(group_idx_shift)
group_ids2 = tf.gather(group_ids, group_ids2_ref % num_groups)
# Also can use unsorted_segment_sum here if preferred
group_sum = tf.math.bincount(group_idx2, weights=flat[:, 1])
# Count number of elements in each output group
out_sizes = tf.math.bincount(group_ids2_ref // num_groups, minlength=s[0])
# Make ragged result
group_sum_r = tf.RaggedTensor.from_row_lengths(group_sum, out_sizes)
group_ids_r = tf.RaggedTensor.from_row_lengths(group_ids2, out_sizes)
result = tf.stack([group_ids_r, group_sum_r], axis=-1)
print(*result.to_list(), sep='\n')
# [[2.0, 1.5], [1.0, 0.10000000149011612], [3.0, 0.4000000059604645]]
# [[2.0, 0.699999988079071], [1.0, 0.5], [4.0, 0.800000011920929]]
# [[3.0, 1.100000023841858], [1.0, 0.10000000149011612], [4.0, 0.800000011920929]]
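
If a dense tensor is still needed downstream, one option (continuing from the ragged snippet above) is to zero-pad the ragged result with RaggedTensor.to_tensor():

# Continues from the snippet above; missing entries are padded with zeros
result_dense = result.to_tensor()
print(result_dense.shape)  # (3, 3, 2) for this example, since every batch has 3 unique ids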

Not exactly what you asked for, but if you know the number of classes and don't want a ragged tensor, you could use one-hot encoding to add up the scores belonging to the same class:

input_tensor = tf.constant([
    [
        [2., 0.7],
        [1., 0.1],
        [3., 0.4],
        [2., 0.8],
    ],
    [
        [2., 0.7],
        [1., 0.1],
        [1., 0.4],
        [4., 0.8],
    ],
    [
        [3., 0.7],
        [1., 0.1],
        [3., 0.4],
        [4., 0.8],
    ]
])


number_of_classes = 5

#first split the labels from scores
labels = tf.expand_dims(input_tensor[:,:,0], axis=-1)
scores = tf.expand_dims(input_tensor[:,:,1], axis=-1)

#get a one-hot encoding for the labels
#the way you do this would likely depend on your specific labels
#the way I do it here is not very robust (maybe use half open intervals instead)
class_indices = tf.reshape(tf.range(number_of_classes, dtype=tf.float32), shape=(1,1,number_of_classes))
one_hots = tf.cast(tf.equal(class_indices, labels), tf.float32)
print(one_hots.shape)  # (batch, N, number_of_classes)

#now multiply the one hots by the scores, and add all together
scored_one_hots = scores * one_hots
scores_per_index = tf.reduce_sum(scored_one_hots, axis=1) # (batch, number_of_classes) 
# where the second index denotes the class and contains the score for that class

# now finish up by combining these scores with the labels
# edit: of course this part too depends on how you actually did the encoding
batch_size = input_tensor.shape[0]
ordered_labels = tf.repeat(tf.expand_dims(tf.range(number_of_classes, dtype=tf.float32), axis=0), batch_size, axis=0)

result = tf.stack([ordered_labels, scores_per_index], axis=2)
print(result)

This prints:

(3, 4, 5)
tf.Tensor(
[[[0.  0. ]
  [1.  0.1]
  [2.  1.5]
  [3.  0.4]
  [4.  0. ]]

 [[0.  0. ]
  [1.  0.5]
  [2.  0.7]
  [3.  0. ]
  [4.  0.8]]

 [[0.  0. ]
  [1.  0.1]
  [2.  0. ]
  [3.  1.1]
  [4.  0.8]]], shape=(3, 5, 2), dtype=float32)

The way you build the one-hot encoding depends on your specific labels (tf.equal may not be the best option; you could use comparisons and so on instead).
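
As a sketch of one alternative encoding (reusing input_tensor from the snippet above), assuming the labels are non-negative integers stored as floats and smaller than number_of_classes, tf.one_hot on the integer-cast labels avoids the exact float equality used above:

import tensorflow as tf

number_of_classes = 5
labels = input_tensor[:, :, 0]   # (batch, N), float labels such as 2., 1., 3., ...
scores = input_tensor[:, :, 1]   # (batch, N)
# tf.one_hot expects integer indices; tf.round before the cast would be safer for noisy floats
one_hots = tf.one_hot(tf.cast(labels, tf.int32), number_of_classes)               # (batch, N, classes)
scores_per_index = tf.reduce_sum(tf.expand_dims(scores, -1) * one_hots, axis=1)   # (batch, classes)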