如何检查张量 A 的哪些元素也出现在张量 B 中,并创建布尔掩码

How to check which elements of a tensor A are present also in a tensor B, and create a boolean mask

在 tensorflow 2.0 中,我有两个整数张量 (tf.uint8),我们称它们为 A 和 B。 张量 A 的秩是任意的,而 B 是单维的。 我正在寻找的结果是获得一个布尔值张量 C (tf.bool),这样:

(举个例子假设A是等级3)

(i,j,k 是这里使用的索引,只是为了阐明概念)

为了求和,我需要检查 A 的元素是否在 B 中,并创建一个掩码 (C),说明 A 的哪些元素在 B 中,哪些不在。

视觉示例(实际上它不是代码,只是研究行为的视觉表示):

 A = [[1,2,3],
     [4,5,6]]

 B = [1,5]

 C = [[True, False, False],
     [False, True, False]]

您可以执行以下操作。我找不到以矢量化方式解决此问题的方法,因为您希望它适用于任意大小的 A。但只要 B 不是很长,它就可以正常工作。

A = tf.constant([[1,2,3],[4,5,6]])

B = tf.constant([1,5])

C = tf.math.greater(tf.reduce_sum(tf.map_fn(lambda b: tf.cast(tf.math.equal(A,b), tf.int32), B), axis=0),0)