在具有 tf.while_loop 的函数上使用 tf.vectorized_map 的问题 (TF 2.0)

Issue with using tf.vectorized_map on a function with a tf.while_loop (TF 2.0)

包含的最小工作示例尝试向量化一个将输入加 10.0 的函数,并通过每次加 1.0(10 次)的 while 循环来实现。函数 运行 在使用 tf.map_fn 时完美运行,而在使用 tf.vectorized_map 时失败。

该函数在使用矢量化地图时不会 运行,并且错误指向 Either add a converter or set --op_conversion_fallback_to_while_loop=True,这可能 运行 更慢。 我究竟做错了什么?注意:这是我问过的同一个问题的副本 https://github.com/tensorflow/tensorflow/issues/46559

if __name__ == "__main__":
    import tensorflow as tf
    @tf.function()
    def add(a):       
        i = tf.constant(0, dtype = tf.int32)
        c = tf.constant(1., dtype = tf.float32)
        loop_index = lambda  i, c, a: i < 10
        def body(i, c, a):
            a = c + a
            i = i + 1
            return i,c, a
        i,c, a = tf.while_loop(loop_index, body, [i,c, a],\
             shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),tf.TensorShape([1])], back_prop= False, parallel_iterations=1)

        return a

    counter =  tf.reshape(tf.range(0, 40, delta = 1, dtype = tf.float32), shape = [40,1])

    all_ = tf.vectorized_map(add, counter) # does not work
    # all_ = tf.map_fn(add, counter) # works as expected
    print(all_, '<-- should be [40,1] float32 tensor with elements [10., 11., ...49.]')

不幸的是,此问题只能在 conda 环境中使用 TF2.2.0 + CUDA 10.1 重现,并且没有任何问题代码。解决方法是 运行 使用不同版本的 tensorflow 的代码。这是一个协作笔记本,展示了问题受让人 ravikyram Link

的 github 问题中的相同内容