用 Tensorflow 2.0 中的另一个张量索引张量的第 k 个维度
Indexing k-th dimension of tensor with another tensor in Tensorflow 2.0
我有一个张量 probs
,其形状为 (None, None, 110)
,表示 LSTM 中的 (batch_size, sequence_length, 110)
。
我有另一个张量 indices
,其形状为 (None, None)
,其中包含从 probs
.
的第三维到 select 的元素索引
我想用indices
索引张量probs
。
Numpy 等价物:
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
由于 probs
的 shape[0]
和 shape[1]
未知,因此 tf.meshgrid()
不是一个选项。
我找到了 tf.gather
、tf.gather_nd
和 tf.batch_gather
,但它们似乎都不符合我的要求。
有人知道怎么做吗?
你可以用 tf.gather_nd
这样做:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
顺便说一句,在 NumPy 中,您可以使用 np.take_along_axis
来做同样的事情:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
我有一个张量 probs
,其形状为 (None, None, 110)
,表示 LSTM 中的 (batch_size, sequence_length, 110)
。
我有另一个张量 indices
,其形状为 (None, None)
,其中包含从 probs
.
我想用indices
索引张量probs
。
Numpy 等价物:
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
由于 probs
的 shape[0]
和 shape[1]
未知,因此 tf.meshgrid()
不是一个选项。
我找到了 tf.gather
、tf.gather_nd
和 tf.batch_gather
,但它们似乎都不符合我的要求。
有人知道怎么做吗?
你可以用 tf.gather_nd
这样做:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
顺便说一句,在 NumPy 中,您可以使用 np.take_along_axis
来做同样的事情:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]