如何通过索引获取张量流 tf?

how to get a tensorflow tf by index?

object_for_each_prior = tf.constant([1 for i in range(8732)])
-><tf.Tensor: shape=(8732,), dtype=int32, numpy=array([1, 1, 1, ..., 1, 1, 1], dtype=int32)>

那如果我想得到位置1148,1149

prior_for_each_object = tf.constant([1148,1149])
object_for_each_prior[prior_for_each_object]

然后我得到如下错误

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1148, 1149], dtype=int32)>

如果我想通过索引获取张量的数字,我应该如何处理它?

使用tf.gather_nd函数索引张量。 示例如下:

>>> object_for_each_prior = tf.constant([1 for i in range(8732)])
>>> prior_for_each_object = tf.gather_nd(object_for_each_prior, indices=[[1148], [1149]])
>>> prior_for_each_object
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1])>
>>> prior_for_each_object.numpy()
array([1, 1])

参考 this 文档以了解有关 tf.gatherr_nd 的更多信息。