创建 tf.data.Dataset 对象时 repeat() 有什么用?

What is the use of repeat() when creating a tf.data.Dataset object?

我正在复现 TensorFlow 的 Time series forecasting 教程的代码。

他们使用 tf.data 对数据集进行洗牌、批处理和缓存。更准确地说,他们执行以下操作:

BATCH_SIZE = 256
BUFFER_SIZE = 10000

train_univariate = tf.data.Dataset.from_tensor_slices((x_train_uni, y_train_uni))
train_univariate = train_univariate.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()

val_univariate = tf.data.Dataset.from_tensor_slices((x_val_uni, y_val_uni))
val_univariate = val_univariate.batch(BATCH_SIZE).repeat()

我不明白为什么他们使用 repeat(),更不明白为什么他们不指定 repeat 的 count 参数。让这个过程无限期重复的意义何在?以及算法如何读取无限大数据集中的所有元素?

正如在 tensorflow federated for image classification 的教程中所见,repeat 方法用于使用数据集的重复,这也将指示 训练的时期数 .

所以使用 .repeat(NUM_EPOCHS) 其中 NUM_EPOCHS 是训练的轮数。