tf.data.Datasets.repeat(EPOCHS) 与 model.fit epochs=EPOCHS 之间的差异

Difference between tf.data.Datasets.repeat(EPOCHS) vs model.fit epochs=EPOCHS

训练时,我将 epochs 设置为迭代数据的次数。当我已经可以用 model.fit(train_dataset,epochs=EPOCHS) 做同样的事情时,我想知道 tf.data.Datasets.repeat(EPOCHS) 有什么用?

它的工作方式略有不同。

让我们选择 2 个不同的例子。

  1. dataset.repeat(20) 和 model.fit(epochs=10)
  2. dataset.repeat(10) 和 model.fit(epochs=20)

我们还假设您有一个包含 100 条记录的数据集。

如果您选择选项 1,每个时期将有 2,000 条记录。在通过模型传递 2,000 条记录后,您将“检查”模型的改进情况,您将执行 10 次。

如果选择选项 2,每个 epoch 将有 1,000 条记录。在推送 1,000 条记录后,您将评估模型的改进情况,您将执行 20 次。

在这两个选项中,您将用于训练的记录总数相同,但评估、记录等模型行为的“时间”不同。

在图像数据的情况下,

tf.data.Datasets.repeat() 可用于 tf.data.Datasets 上的数据扩充。

假设您想增加训练数据集中的图像数量,使用随机变换然后重复训练数据集 count 次并应用如下所示的随机变换

train_dataset = (
    train_dataset
    .map(resize, num_parallel_calls=AUTOTUNE)
    .map(rescale, num_parallel_calls=AUTOTUNE)
    .map(onehot, num_parallel_calls=AUTOTUNE)
    .shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .repeat(count=5)
    .map(random_flip, num_parallel_calls=AUTOTUNE)
    .map(random_rotate, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

如果没有 repeat() 方法,您必须创建数据集的副本,单独应用转换,然后连接数据集。但是使用 repeat() 简化了这一点,还利用了方法链接并拥有整洁的代码。

更多关于数据增强的信息:https://www.tensorflow.org/tutorials/images/data_augmentation#apply_augmentation_to_a_dataset