如何在给定 TF2.1 中每一行的起始索引的情况下对张量进行切片?

How to slice a Tensor given the starting indices for each each row in TF2.1?

给定一些(至少是二维的)输入,例如:

inputs = [['a0', 'a1', 'a2', 'a3', 'a4'],
          ['b0', 'b1', 'b2', 'b3', 'b4'],
          ['c0', 'c1', 'c2', 'c3', 'c4']]

...以及另一个索引输入和标量 window 大小:

indices = [2, 3, 0]  # representing the starting positions (2nd dimension)
window_size = 2      # fixed-width of each window

如何在 Tensorflow 2 中从这些索引开始获取 windows?我首先考虑使用像 inputs[,start:start+window_size] 这样的常规切片,但这不适用,因为这只允许对所有行使用一个起始索引,并且不支持每行使用不同的索引。


此示例的预期输出为:

output = [['a2', 'a3'],
          ['b3', 'b4'],
          ['c0', 'c1']]

我提供一种向量化方法。矢量化方法将明显快于 tf.map_fn().

import tensorflow as tf

inputs = tf.constant([['a0', 'a1', 'a2', 'a3', 'a4'],
                      ['b0', 'b1', 'b2', 'b3', 'b4'],
                      ['c0', 'c1', 'c2', 'c3', 'c4']])
indicies = tf.constant([2, 3, 0])
window_size = 2

start_index = tf.sequence_mask(indicies,inputs.shape[-1])
# tf.Tensor(
# [[ True  True False False False]
#  [ True  True  True False False]
#  [False False False False False]], shape=(3, 5), dtype=bool)
end_index = tf.sequence_mask(indicies+window_size,inputs.shape[-1])
# tf.Tensor(
# [[ True  True  True  True False]
#  [ True  True  True  True  True]
#  [ True  True False False False]], shape=(3, 5), dtype=bool)

index = tf.not_equal(start_index,end_index)
# tf.Tensor(
# [[False False  True  True False]
#  [False False False  True  True]
#  [ True  True False False False]], shape=(3, 5), dtype=bool)

result = tf.reshape(tf.boolean_mask(inputs,index),
                    indicies.get_shape().as_list()+[window_size])
print(result)
# tf.Tensor(
# [[b'a2' b'a3']
#  [b'b3' b'b4']
#  [b'c0' b'c1']], shape=(3, 2), dtype=string)