Tensorflow 2.0 数据集和数据加载器

Tensorflow 2.0 dataset and dataloader

我是pytorch用户,用惯了pytorch中的data.dataset和data.dataloaderapi。我正在尝试使用 tensorflow 2.0 构建相同的模型,我想知道是否有一个 api 与 pytorch 中的这些 api 类似。

如果没有这样的api,你们谁能告诉我人们通常如何在tensorflow中实现数据加载部分?我使用过 tensorflow 1,但从未使用过数据集 api。我以前硬编码过。我希望有类似覆盖 getitem 的东西,只有索引作为输入。

在此先致谢。

我不熟悉 Pytorch,但 Tensorflow 实现了 Keras API,它具有序列 class 即:

Base object for fitting to a sequence of data, such as a dataset

https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence

此 class 包含索引的 getitem。

当使用tf.data API时,您通常也会使用<a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map" rel="noreferrer">map</a>函数。

在 PyTorch 中,您的 __getItem__ 调用基本上是从 __init__ 中给出的数据结构中获取一个元素,并在必要时对其进行转换。

在 TF2.0 中,您可以通过初始化一个 <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset" rel="noreferrer">Dataset</a> using one of the Dataset.from_... functions (see <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator" rel="noreferrer">from_generator</a>, <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices" rel="noreferrer">from_tensor_slices</a>, <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors" rel="noreferrer">from_tensors</a>);这本质上是 PyTorch Dataset__init__ 部分。然后,您可以调用 map 来执行您在 __getItem__.

中将进行的逐元素操作

Tensorflow 数据集是非常奇特的迭代器,因此根据设计,您不会使用索引访问它们的元素,而是通过遍历它们。

tf.data 上的 guide 非常有用,提供了各种各样的示例。