用 Tensorflow 2.0 中的另一个张量索引张量的第 k 个维度

Indexing k-th dimension of tensor with another tensor in Tensorflow 2.0

我有一个张量 probs,其形状为 (None, None, 110),表示 LSTM 中的 (batch_size, sequence_length, 110)。 我有另一个张量 indices,其形状为 (None, None),其中包含从 probs.

的第三维到 select 的元素索引

我想用indices索引张量probs

Numpy 等价物:

k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]

由于 probsshape[0]shape[1] 未知,因此 tf.meshgrid() 不是一个选项。 我找到了 tf.gathertf.gather_ndtf.batch_gather,但它们似乎都不符合我的要求。

有人知道怎么做吗?

你可以用 tf.gather_nd 这样做:

indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)

顺便说一句,在 NumPy 中,您可以使用 np.take_along_axis 来做同样的事情:

indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]