Tensorflow:沿张量轴的最大值的二进制掩码
Tensorflow: binary mask of max values along tensor axis
如果我有一个 N 维张量,我想创建另一个值为 0 和 1 的张量(具有相同的形状),其中 1 在某个维度上与原始张量中的最大元素处于相同的位置.
我的一个限制是我只想获得沿该轴的 第一个最大元素,以防重复。
为了简化,我将使用较少的维度。
>>> x = tf.constant([[7, 2, 3],
[5, 0, 1],
[3, 8, 2]], dtype=tf.float32)
>>> tf.reduce_max(x, axis=-1)
tf.Tensor([7. 5. 8.], shape=(3,), dtype=float32)
我想要的是:
tf.Tensor([1. 0. 0.],
[1. 0. 0.],
[0. 1. 0.], shape=(3,3), dtype=float32)
我尝试过的(并意识到是错误的):
>>> tf.cast(tf.equal(x, tf.reduce_max(x, axis=-1, keepdims=True)), dtype=tf.float32)
# works fine when there are no duplicates
tf.Tensor([[1. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]], shape=(3, 3), dtype=float32)
>>> y = tf.zeros([3,3])
>>> tf.cast(tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), dtype=tf.float32)
# fails when there are multiple identical values across dimension
tf.Tensor([[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]], shape=(3, 3), dtype=float32)
编辑:已解决
tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1, direction='DESCENDING'), 1), 0), tf.float32)
您可以使用双精度 tf.argsort()
来获取元素沿轴 1 的排名顺序并获得最大排名。这将最大值的 last
实例作为最高排名。让我们以重复元素为例 -
x = tf.constant([[7, 2, 3], #max is 7
[5, 0, 5], #max is 5 but duplicate in same row
[7, 8, 7]]) #max is 8 but shares 7 with first row too
tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1), 1), x.shape[0]-1), tf.int64)
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 0, 0],
[0, 0, 1],
[0, 1, 0]], dtype=int32)>
如果我有一个 N 维张量,我想创建另一个值为 0 和 1 的张量(具有相同的形状),其中 1 在某个维度上与原始张量中的最大元素处于相同的位置.
我的一个限制是我只想获得沿该轴的 第一个最大元素,以防重复。
为了简化,我将使用较少的维度。
>>> x = tf.constant([[7, 2, 3],
[5, 0, 1],
[3, 8, 2]], dtype=tf.float32)
>>> tf.reduce_max(x, axis=-1)
tf.Tensor([7. 5. 8.], shape=(3,), dtype=float32)
我想要的是:
tf.Tensor([1. 0. 0.],
[1. 0. 0.],
[0. 1. 0.], shape=(3,3), dtype=float32)
我尝试过的(并意识到是错误的):
>>> tf.cast(tf.equal(x, tf.reduce_max(x, axis=-1, keepdims=True)), dtype=tf.float32)
# works fine when there are no duplicates
tf.Tensor([[1. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]], shape=(3, 3), dtype=float32)
>>> y = tf.zeros([3,3])
>>> tf.cast(tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), dtype=tf.float32)
# fails when there are multiple identical values across dimension
tf.Tensor([[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]], shape=(3, 3), dtype=float32)
编辑:已解决
tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1, direction='DESCENDING'), 1), 0), tf.float32)
您可以使用双精度 tf.argsort()
来获取元素沿轴 1 的排名顺序并获得最大排名。这将最大值的 last
实例作为最高排名。让我们以重复元素为例 -
x = tf.constant([[7, 2, 3], #max is 7
[5, 0, 5], #max is 5 but duplicate in same row
[7, 8, 7]]) #max is 8 but shares 7 with first row too
tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1), 1), x.shape[0]-1), tf.int64)
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 0, 0],
[0, 0, 1],
[0, 1, 0]], dtype=int32)>