XLA can't deduce compile time constant output shape for strided slice when using ragged tensor and while loop
Is there any way to get the following minimal example to work with experimental_compile=True? I've seen some big speedups with this argument, so I'd very much like to figure out how to make it work. Thanks!
import tensorflow as tf
print(tf.__version__)
# ===> 2.2.0-dev20200409
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
# [5. 6. 7. 8. 9.]]
# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]
# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
# [20. 21. 22. 23. 24.]]
@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():
    num_rows = ragged_tensor.nrows()
    def cond(i, _):
        return i < num_rows
    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])
    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total
while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
# [[{{node while/RaggedGetItem/strided_slice_4}}]]
# [[while]]
# This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]
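Side note: the auto-clustering fallback that the error message mentions would look roughly like this. This is only a sketch on my side; the flag spelling "--tf_xla_auto_jit=2" and the set_jit call are taken from the XLA documentation as I understand it, and I haven't compared its speed to compiling the whole function.
import os
# Auto-clustering lets XLA compile only the subgraphs it can handle instead of
# forcing the whole function through experimental_compile=True. The flag must be
# set before TensorFlow is imported.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"
import tensorflow as tf
# Assumed programmatic alternative to the environment variable:
# tf.config.optimizer.set_jit(True)
# The while_loop_fail function above would then be decorated with a plain
# @tf.function(autograph=False), without experimental_compile=True.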
There seem to be quite a few restrictions on how XLA handles ragged tensors. There are a couple of alternatives I can think of that would make your example work, although I don't know whether they apply to your actual use case. On one hand, you could sum over the ragged dimension in advance, or even over all dimensions except the first one. However, this needs to be done outside of XLA, since it does not seem to be able to compile it:
import tensorflow as tf
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
# Sum in advance
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])
@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()
    def cond(i, _):
        return i < num_rows
    def body(i, running_total):
        # Use the sums computed before
        return i + 1, running_total + ragged_sum[i]
    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total
result = while_loop_works()
print(result.numpy())
# 300.0
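As far as I can tell, the reason this version compiles is that the reduction over the ragged dimensions happens eagerly, before XLA ever sees the graph: inside the compiled loop, ragged_sum[i] is a scalar slice of a dense rank-1 tensor, so its shape is a compile-time constant, unlike the [?, 5] slice that ragged_tensor[i] produces in the failing example.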
You can also just convert the ragged tensor into a regular tensor, which pads it with zeros that do not affect the sum. Again, this currently needs to be done outside of XLA:
import tensorflow as tf
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
# Convert into a regular tensor
unragged_tensor = ragged_tensor.to_tensor()
@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()
    def cond(i, _):
        return i < num_rows
    def body(i, running_total):
        # Reduce padded tensor
        return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])
    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total
result = while_loop_works()
print(result.numpy())
# 300.0
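To double-check the claim that the zero padding does not change the result, here is a quick standalone comparison (nothing XLA-specific; to_tensor pads with its default fill value of 0.0):
import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, tf.constant([2, 1, 2]))

# to_tensor() pads the short rows with 0.0, so the total sum is unchanged.
padded = ragged_tensor.to_tensor()
print(padded.shape)                          # (3, 2, 5)
print(tf.reduce_sum(ragged_tensor).numpy())  # 300.0
print(tf.reduce_sum(padded).numpy())         # 300.0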
For anyone who runs into this kind of problem: I just noticed that this works on TensorFlow 2.5 (replacing experimental_compile with jit_compile):
import tensorflow as tf
print(tf.__version__)
# 2.5.0
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
# [5. 6. 7. 8. 9.]]
# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]
# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
# [20. 21. 22. 23. 24.]]
@tf.function(autograph=False, jit_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()
    def cond(i, _):
        return i < num_rows
    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])
    _, total = tf.while_loop(cond, body, [0, 0.0])
    return total
while_loop_works()
# 2021-06-28 13:18:19.253261: I tensorflow/compiler/jit/xla_compilation_cache.cc:337] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
# <tf.Tensor: shape=(), dtype=float32, numpy=300.0>