在将 tf.data.Dataset 传递给 fit 函数时,我们是否应该对 tf.data.Dataset 应用 repeat, batch shuffle ?

Should we apply repeat, batch shuffle to tf.data.Dataset when passing it to fit function?

在阅读了关于 tf.keras.Model.fittf.data.Dataset 的文档后,我仍然不知道,当将 tf.data.Dataset 传递给 fit 函数时,我应该调用 repeatbatch 在数据集对象上还是我应该提供 batch_sizeepochs 参数来代替?或两者?我应该对验证集应用相同的处理方式吗?

当我在这里时,我可以 shuffle fit 之前的数据集吗? (似乎这是一个明显的是) 如果是这样,在调用 Dataset.batchDataset.repeat(如果调用它们)之前、之后?

编辑: 当使用 batch_size 参数时,并且之前没有调用 Dataset.batch(batch_size),我收到以下错误:

ValueError: The `batch_size` argument must not be specified for the given input type.
Received input: <MapDataset shapes: ((<unknown>, <unknown>, <unknown>, <unknown>), (<unknown>, <unknown>, <unknown>)), 
types: ((tf.float32, tf.float32, tf.float32, tf.float32), (tf.float32, tf.float32, tf.float32))>, 
batch_size: 1

谢谢

这里有多种方法可以满足您的需求,但我经常使用的方法是:

batch_size = 32
ds = tf.Dataset()
ds = ds.shuffle(len_ds)
train_ds = ds.take(0.8*len_ds)
train_ds = train_ds.repeat().batch(batch_size)
validation_ds = ds.skip(0.8*len_ds)
validation_ds = train_ds.repeat().batch(batch_size)
model.fit(train_ds,
          steps_per_epoch = len_train_ds // batch_size,
          validation_data = validation_ds,
          validation_steps = len_validation_ds // batch_size,
          epochs = 5)

这样你也可以在模型拟合后访问所有变量,例如,如果你想可视化验证集,你可以。这在 validation_split 中是不可能的。如果你删除 .batch(batch_size),你应该删除 // batch_sizes,但我会保留它们,因为现在发生的事情更清楚。

你总是需要提供纪元。

计算 train/validation 集的长度需要您遍历它们:

len_train_ds = 0
for i in train_ds:
  len_train_ds += 1

如果在 tf.Dataset 表格中。