Select 张量流中另一个列表中的元素

Select elemetn from another list in tensorflow

我正在研究 tensorflow,我遇到了以下问题

import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn

#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]], 
         [[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])


#find the value corresponding to label index by row
output = nn.log_softmax(logit)

我有

output = tf.Tensor(
[[[-0.74085818 -1.14085818 -1.59085818]
  [-0.97945321 -1.31945321 -1.02945321]
  [-1.43897936 -0.97897936 -0.94897936]]

 [[-1.27561467 -1.12561467 -0.92561467]
  [-1.38741927 -1.14741927 -0.83741927]
  [-0.88817684 -1.22817684 -1.21817684]]], shape=(2, 3, 3), dtype=float64)

我想 select output 中的元素通过 label 中的索引。也就是说,我最终的结果应该是

[[1.59085822 0.97945321 0.97897935]  #2, 0, 1
[1.27561462 0.83741927 1.22817683]], #0, 2, 1
shape=(2, 3), dtype=float64)

您不能直接执行此操作。实现这一目标的正确方法是首先在您的标签上应用 one hot encoding。然后从输出 logits 使用 tf.boolean_mask 到 select。

这是一个例子:

import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn

#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]], 
         [[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])


#find the value corresponding to label index by row
output = nn.log_softmax(logit)

one_hot = tf.one_hot(label, 3, dtype=tf.int32)
# <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
# array([[[0, 0, 1],
#         [1, 0, 0],
#         [0, 1, 0]],
# 
#        [[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]]], dtype=int32)>

result_vec = tf.boolean_mask(output, one_hot) # The result is a vector
# <tf.Tensor: shape=(6,), dtype=float64, numpy=
# array([-1.59085818, -0.97945321, -0.97897936, -1.27561467, -0.83741927,
#        -1.22817684])>

result = tf.reshape(result_vec, label.shape)

结果将是:(您是否遗漏了问题中的负号?)

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-1.59085818, -0.97945321, -0.97897936],
       [-1.27561467, -0.83741927, -1.22817684]])>