Batched elements after shuffling are seemingly non-consecutive in TensorFlow 2.x

I have the following simple example:

import tensorflow as tf

# Features: 4 rows of 3 values; labels: one value per row
tensor1 = tf.constant(value=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
tensor2 = tf.constant(value=[20, 21, 22, 23])

print(tensor1.shape)
print(tensor2.shape)

# Each dataset element is a (row, label) pair
dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))

print('Original dataset')
for i in dataset:
    print(i)

# Repeat the 4 elements 3 times, giving 12 elements
dataset = dataset.repeat(3)

print('Repeated dataset')
for i in dataset:
    print(i)

As expected, it returns:

(4, 3)
(4,)
Original dataset
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)
Repeated dataset
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)

If I then batch the dataset as follows:

dataset = dataset.batch(3)

print('Batched dataset')
for i in dataset:
    print(i)

I get, as expected:

Batched dataset
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[10, 11, 12],
       [ 1,  2,  3],
       [ 4,  5,  6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 7,  8,  9],
       [10, 11, 12],
       [ 1,  2,  3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)

Each batch takes consecutive elements of the dataset.

However, when I shuffle first and then batch:

# Note: applied to the repeated dataset, before any batching
dataset = dataset.shuffle(3)

print('Shuffled dataset')
for i in dataset:
    print(i)

dataset = dataset.batch(3)

print('Batched dataset')
for i in dataset:
    print(i)

the batched elements are not consecutive:

Shuffled dataset
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=21>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 11, 12], dtype=int32)>, <tf.Tensor: shape=(), dtype=int32, numpy=23>)
Batched dataset
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[7, 8, 9],
       [1, 2, 3],
       [1, 2, 3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 20, 20], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[10, 11, 12],
       [ 4,  5,  6],
       [ 7,  8,  9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 21, 22], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[10, 11, 12],
       [ 1,  2,  3],
       [ 4,  5,  6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)

I am using Google Colab with TensorFlow 2.x.

My question is: why does shuffling before batching make batch return non-consecutive elements?

Thanks for any answers.

I think you should definitely read this guide: tf.data: Build TensorFlow input pipelines

That is exactly what shuffling does. Ignoring the repeat for simplicity, you started with this:

[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

You specified buffer_size=3, so it creates a buffer holding the first 3 elements:

[[1, 2, 3], [4, 5, 6], [7, 8, 9]]

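You specified batch_size=3, so each batch needs three elements. For the first one, shuffle picks an element at random from this buffer and replaces it with the first element outside the initial buffer. Suppose [1, 2, 3] is picked; your batch is now: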

[[1, 2, 3]]

Your buffer is now:

[[10, 11, 12], [4, 5, 6], [7, 8, 9]]

For the second element of the batch (batch_size=3), it again picks at random from this buffer. Suppose [7, 8, 9] is picked; your batch is now:

[[1, 2, 3], [7, 8, 9]]

Your buffer is now:

[[10, 11, 12], [4, 5, 6]]

There is nothing left to refill the buffer with, so it randomly picks one of the remaining elements, say [10, 11, 12]. Your batch is now:

[[1, 2, 3], [7, 8, 9], [10, 11, 12]]

The next batch will contain only [4, 5, 6], because batch uses drop_remainder=False by default, so the final partial batch is kept rather than dropped.
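The walkthrough above can be sketched in plain Python, without TensorFlow. This is only an illustration of the buffer idea under simplifying assumptions, not how tf.data is actually implemented: the names shuffle_with_buffer and batch are hypothetical helpers written for this example, and the random draws will not match TensorFlow's.

```python
import random


def shuffle_with_buffer(elements, buffer_size, rng):
    """Yield elements in the order a fixed-size shuffle buffer would emit them.

    Hypothetical helper that mimics the idea behind Dataset.shuffle(buffer_size).
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    source = iter(elements)
    buffer = []
    # Fill the buffer with the first `buffer_size` elements.
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Emit a random buffer entry; refill its slot from the source while possible.
    while buffer:
        index = rng.randrange(len(buffer))
        picked = buffer[index]
        try:
            buffer[index] = next(source)
        except StopIteration:
            # Nothing left to refill with, so the buffer shrinks.
            buffer.pop(index)
        yield picked


def batch(elements, batch_size, drop_remainder=False):
    """Group consecutive elements into lists of `batch_size` (hypothetical helper)."""
    current = []
    for item in elements:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    # Keep the final partial batch unless drop_remainder is set.
    if current and not drop_remainder:
        yield current


rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
rng = random.Random(0)  # fixed seed so repeated runs give the same order
batches = list(batch(shuffle_with_buffer(rows, buffer_size=3, rng=rng), batch_size=3))
for b in batches:
    print(b)
```

With four rows and buffer_size=3, the first emitted row can never be [10, 11, 12], because that row only enters the buffer after the first pick. Note also that TensorFlow reshuffles on every pass over the dataset by default (the reshuffle_each_iteration argument of shuffle), which is likely why the "Shuffled dataset" printout in the question does not line up with the batches printed afterwards.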