计算tensorflow 2.0中句子的填充嵌入查找中的原始序列长度

calculate original sequence length in padded embedding lookup of a sentence in tensorflow 2.0

text_tensor 是一个形状为 [None,sequence_max_length,embedding_dim] 的张量,它包含对一批序列的嵌入查找。使用零填充序列。我需要获得一个名为 text_lengths 的列表,形状为 [None](None 是批量大小),其中包含每个序列的长度,没有填充。我已经尝试了几个脚本。

我得到的最接近的是下面的代码:

 text_lens = tf.math.reduce_sum(tf.cast(tf.math.not_equal(text_tensor, tf.as_tensor(numpy.zeros([embedding_dim]))), dtype=tf.int32), axis=-1)

但还是算错了长度。谁能帮我解决这个问题?

如果我没理解错的话,在序列的原始长度之后,第一个轴的剩余索引的大小为 0 embedding_dim

import tensorflow as tf

# batch_size = 2, first sequence length = 1, second sequence length = 3
data = [[[1, 1, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [0, 0, 0, 0]]]

with tf.compat.v1.Session() as sess:
    tensor = tf.constant(data, dtype=tf.int32)
    check = tf.reduce_all(tf.not_equal(tensor, 0), axis=-1)
    lengths = tf.reduce_sum(tf.cast(check, tf.int32), axis=-1)
    print(sess.run(lengths))

输出

[1 3]