张量流索引

Tensorflow indexing

我想按所有列的行范围选择一个张量。

类似于:

x[10:20,:]

第 10 到 20 行以及所有列。

我试过使用:

tf.gather_nd

有什么方法可以做到这一点?

我相信你正在寻找 tf.slice

因此假设您有一个 2D 张量,如您的示例所示,要获得第一维的 10:20,您将执行以下操作: tf.slice(x, begin = [10,0], size = [10, x.get_shape().as_list()[1]])

注意:begin 是 0 索引,size 是 1 索引。所以我给出的尺寸会给你从 starting/begin 点开始的长度为 10 的维度 1 和所有维度 2.

Tensorflow 支持 numpy 样式索引:x[10:20,:]

一个例子:

 x = tf.Variable(tf.truncated_normal(shape=(100, 100)))
 y = x[10:20,]

 sess = tf.InteractiveSession()
 tf.global_variables_initializer().run()
 y.eval().shape
 #output
 #(10, 100)