在具有 tf.while_loop 的函数上使用 tf.vectorized_map 的问题 (TF 2.0)
Issue with using tf.vectorized_map on a function with a tf.while_loop (TF 2.0)
包含的最小工作示例尝试向量化一个将输入加 10.0 的函数,并通过每次加 1.0(10 次)的 while 循环来实现。函数 运行 在使用 tf.map_fn 时完美运行,而在使用 tf.vectorized_map 时失败。
该函数在使用矢量化地图时不会 运行,并且错误指向 Either add a converter or set --op_conversion_fallback_to_while_loop=True
,这可能 运行 更慢。
我究竟做错了什么?注意:这是我问过的同一个问题的副本 https://github.com/tensorflow/tensorflow/issues/46559。
if __name__ == "__main__":
import tensorflow as tf
@tf.function()
def add(a):
i = tf.constant(0, dtype = tf.int32)
c = tf.constant(1., dtype = tf.float32)
loop_index = lambda i, c, a: i < 10
def body(i, c, a):
a = c + a
i = i + 1
return i,c, a
i,c, a = tf.while_loop(loop_index, body, [i,c, a],\
shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),tf.TensorShape([1])], back_prop= False, parallel_iterations=1)
return a
counter = tf.reshape(tf.range(0, 40, delta = 1, dtype = tf.float32), shape = [40,1])
all_ = tf.vectorized_map(add, counter) # does not work
# all_ = tf.map_fn(add, counter) # works as expected
print(all_, '<-- should be [40,1] float32 tensor with elements [10., 11., ...49.]')
不幸的是,此问题只能在 conda 环境中使用 TF2.2.0 + CUDA 10.1 重现,并且没有任何问题代码。解决方法是 运行 使用不同版本的 tensorflow 的代码。这是一个协作笔记本,展示了问题受让人 ravikyram Link
的 github 问题中的相同内容
包含的最小工作示例尝试向量化一个将输入加 10.0 的函数,并通过每次加 1.0(10 次)的 while 循环来实现。函数 运行 在使用 tf.map_fn 时完美运行,而在使用 tf.vectorized_map 时失败。
该函数在使用矢量化地图时不会 运行,并且错误指向 Either add a converter or set --op_conversion_fallback_to_while_loop=True
,这可能 运行 更慢。
我究竟做错了什么?注意:这是我问过的同一个问题的副本 https://github.com/tensorflow/tensorflow/issues/46559。
if __name__ == "__main__":
import tensorflow as tf
@tf.function()
def add(a):
i = tf.constant(0, dtype = tf.int32)
c = tf.constant(1., dtype = tf.float32)
loop_index = lambda i, c, a: i < 10
def body(i, c, a):
a = c + a
i = i + 1
return i,c, a
i,c, a = tf.while_loop(loop_index, body, [i,c, a],\
shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),tf.TensorShape([1])], back_prop= False, parallel_iterations=1)
return a
counter = tf.reshape(tf.range(0, 40, delta = 1, dtype = tf.float32), shape = [40,1])
all_ = tf.vectorized_map(add, counter) # does not work
# all_ = tf.map_fn(add, counter) # works as expected
print(all_, '<-- should be [40,1] float32 tensor with elements [10., 11., ...49.]')
不幸的是,此问题只能在 conda 环境中使用 TF2.2.0 + CUDA 10.1 重现,并且没有任何问题代码。解决方法是 运行 使用不同版本的 tensorflow 的代码。这是一个协作笔记本,展示了问题受让人 ravikyram Link
的 github 问题中的相同内容