Tensorflow:张量的交叉索引切片
Tensorflow: cross index slicing of a tensor
我有两个如下形状的张量:
tensor1 => shape(10, 99, 106)
tensor2 => shape(10, 99)
tensor2
包含范围从 0 - 105
的值,我希望使用它来切片 tensor1
的最后一个维度并获得形状 tensor3
[=18] =]
tensor3 => shape(10, 99, 99)
我试过使用:
tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)
另外,使用
tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: last dimension of tensor2 (which is 99) must be
# less than the rank of the tensor1 (which is 3).
我正在为此寻找类似于 numpy 的 cross_indexing 的东西。
您可以使用 tf.map_fn
:
tensor3 = tf.map_fn(lambda u: tf.gather(u[0],u[1],axis=1),[tensor1,tensor2],dtype=tensor1.dtype)
您可以将此行视为在 tensor1
和 tensor2
的第一个维度上运行的循环,并且对于第一个维度中的每个索引 i
它应用 tf.gather
在 tensor1[i,:,:]
和 tensor2[i,:]
.
我有两个如下形状的张量:
tensor1 => shape(10, 99, 106)
tensor2 => shape(10, 99)
tensor2
包含范围从 0 - 105
的值,我希望使用它来切片 tensor1
的最后一个维度并获得形状 tensor3
[=18] =]
tensor3 => shape(10, 99, 99)
我试过使用:
tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)
另外,使用
tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: last dimension of tensor2 (which is 99) must be
# less than the rank of the tensor1 (which is 3).
我正在为此寻找类似于 numpy 的 cross_indexing 的东西。
您可以使用 tf.map_fn
:
tensor3 = tf.map_fn(lambda u: tf.gather(u[0],u[1],axis=1),[tensor1,tensor2],dtype=tensor1.dtype)
您可以将此行视为在 tensor1
和 tensor2
的第一个维度上运行的循环,并且对于第一个维度中的每个索引 i
它应用 tf.gather
在 tensor1[i,:,:]
和 tensor2[i,:]
.