Tensorflow 2.0 数据集批处理无法正常工作

Tensorflow 2.0 dataset batching not working properly

Tensorflow 2.0 数据集 api 的批处理没有像我预期的那样工作。

我做了一个这样的数据集。

self.train_dataset = tf.data.Dataset.from_generator(generator=train_generator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([6]), tf.TensorShape([])))

这会产生 DatasetV1Adapter 形状:((6,), ()),类型:(tf.float32, tf.float32), 对于这个数据集,我应用了 tf.data.Dataset.

中的批处理函数
self.train_dataset.batch(1024)

生成 DatasetV1Adapter 形状:((None, 6), (None,)),类型:(tf.float32, tf.float32),并且更改批量大小不会帮助。

根据批次的官方描述,

The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.

我认为此功能的工作方式是制作 [batch, 6], [batch,] 但效果不佳。

我最初使用 pytorch,最近开始使用 TF 2.0,需要一些有关正确批处理的帮助。提前致谢。

你可以通过设置得到想要的结果,

train_dataset = train_dataset.batch(2, drop_remainder=True)
默认为

drop_remainder=False。在这种情况下,第一个维度 必须 None 因为在数据集的末尾(很可能)会有一个包含 < batch_size 元素的批次,因为样本数不能被 batch_size.

整除