具有随机张量的压缩张量流数据集的奇怪行为

Weird behavior of zipped tensorflow dataset with random tensors

在下面的示例 (Tensorflow 2.0) 中,我们有一个包含三个元素的虚拟张量流数据集。我们在其上映射一个函数 (replace_with_float),该函数 returns 是一个随机生成的值,分为两份。正如我们所料,当我们从数据集中取元素时,第一个和第二个坐标具有相同的值。

现在,我们分别从第一个坐标和第二个坐标创建两个 "slice" 数据集,并将这两个数据集压缩回一起。切片和压缩操作似乎彼此相反,所以我希望生成的数据集与前一个数据集等效。但是,正如我们所见,现在第一个和第二个坐标是随机生成的不同值。

也许更有趣的是,如果我们通过 df = tf.data.Dataset.zip((df.map(lambda x, y: x), df.map(lambda x, y: x))),两个坐标也会有不同的值

如何解释这种行为?也许两个不​​同的图是为要压缩的两个数据集构建的并且它们是独立的运行?

import tensorflow as tf

def replace_with_float(element):
    rand = tf.random.uniform([])
    return (rand, rand)

df = tf.data.Dataset.from_tensor_slices([0, 0, 0])
df = df.map(replace_with_float)
print('Before zipping: ')
for x in df:
    print(x[0].numpy(), x[1].numpy())

df = tf.data.Dataset.zip((df.map(lambda x, y: x), df.map(lambda x, y: y)))

print('After zipping: ')
for x in df:
    print(x[0].numpy(), x[1].numpy())

示例输出:

Before zipping: 
0.08801079 0.08801079
0.638958 0.638958
0.800568 0.800568
After zipping: 
0.9676769 0.23045003
0.91056764 0.6551999
0.4647777 0.6758332

简短的回答是数据集不会缓存完整迭代之间的中间值,除非您明确要求使用 df.cache(),并且它们也不会重复删除公共输入。

所以在第二个循环中,整个流水线再次运行s。 同样,在第二个实例中,两次 df.map 调用导致 df 到 运行 两次。

添加 tf.print 有助于解释发生的情况:

def replace_with_float(element):
    rand = tf.random.uniform([])
    tf.print('replacing', element, 'with', rand)
    return (rand, rand)

我还把 lambda 放在不同的行上以避免签名警告:

first = lambda x, y: x
second = lambda x, y: y

df = tf.data.Dataset.zip((df.map(first), df.map(second)))
Before zipping: 
replacing 0 with 0.624579549
0.62457955 0.62457955
replacing 0 with 0.471772075
0.47177207 0.47177207
replacing 0 with 0.394005418
0.39400542 0.39400542

After zipping: 
replacing 0 with 0.537954807
replacing 0 with 0.558757305
0.5379548 0.5587573
replacing 0 with 0.839109302
replacing 0 with 0.878996611
0.8391093 0.8789966
replacing 0 with 0.0165234804
replacing 0 with 0.534951568
0.01652348 0.53495157

为避免重复输入问题,您可以使用单个 map 调用:

swap = lambda x, y: (y, x)
df = df.map(swap)

或者您可以使用 df = df.cache() 来避免这两种影响:

df = df.map(replace_with_float)
df = df.cache()
Before zipping: 
replacing 0 with 0.728474379
0.7284744 0.7284744
replacing 0 with 0.419658661
0.41965866 0.41965866
replacing 0 with 0.911524653
0.91152465 0.91152465

After zipping: 
0.7284744 0.7284744
0.41965866 0.41965866
0.91152465 0.91152465