在张量流数据集上应用地图执行速度非常慢

Applying map on tensorflow Dataset performs very slowly

我在 python 3.8 中使用 Tensorflow 2.2。我有一个张量切片的数据集对象构建,需要在数据集的每个张量上应用一些计算,称之为 compute。为此,我使用 tf.data.Datasetmap 功能(代码见下文)。然而,与在每个张量上直接应用给定方法相比,该映射的执行速度相当慢。这是模型案例(下面的代码保存在名为 test.py 的文件中)。

import tensorflow as tf

class Test:
    def __init__(self):
        pass

    @tf.function
    def compute(self, tensor):
        # the main function that performs some computation with a tensor
        print('python.print ===> tracing compute ... ')

        res = tensor*tensor
        res = tf.signal.rfft(res)  # perform some computationally heavy task

        return res

    def apply_on_ds(self, ds):
        # mapping the compute method on a dataset
        return ds.map(lambda x: self.compute( x ) )

    @tf.function
    def apply_on_tensors(self, tensors):
        # a direct application on tensors of the compute method
        for i in tf.range(tensors.shape[0]):
            res = self.compute(tensors[i] )

为了运行上面的代码,存储在test.py,我做了如下操作

import tensorflow as tf
import time

import test

T = test.Test()
tensors = tf.random.uniform(shape=[100, 10000], dtype=tf.float32)
ds      = tf.data.Dataset.from_tensor_slices(tensors)

t1 = time.time(); x = list( T.apply_on_ds(ds) );  t2 = time.time();
# t2 - t1 equals ~1.08 sec on my computer

t1 = time.time() ;  x = T.apply_on_tensors(tensors);   t2 = time.time();
# t2 - t1 equals ~0.03 sec on my computer

Why is there such a massive gap in performance between applying the map and applying the same function used for map directly ?

当我将 map 设置的 num_parallel_callsdeterministic 参数添加到相应的 8 (我机器上的核心数)和 False~0.16 sec 中处理 运行s(相对于没有并行化的 ~1 sec)。尽管如此,这仍然比直接应用map.

中使用的方法要差得多。

我这里有明显的错误吗?我怀疑在使用地图时对图形进行了一些回溯,但是,我找不到这方面的证据。对以上内容的任何解释和改进建议将不胜感激。

我正在回答我的问题,以防万一有人遇到问题中描述的相同问题。以下内容基于对此 Github issue 的评论(更多信息可在其中找到)。

代码本身没有问题。性能上的差距是因为map操作(op)是放在CPU上执行的,而函数中使用的函数的逐一应用map 发生在 GPU 上,因此性能存在差异。要看到这一点,可以添加

tf.debugging.set_log_device_placement(True) 

代码以访问有关 Tensorflow 将其操作放置在何处的信息。 要强制 map 在 GPU 上执行,可以采用

中的 compute 方法的计算
with tf.device("/gpu:0"):    

块(参见上面的 link)。