如何在 tensorflow 上训练大型数据集 2.x

How to train large dataset on tensorflow 2.x

我有一个包含大约 200 万行和 6,000 列的大型数据集。输入的 numpy 数组 (X, y) 可以很好地保存训练数据。但是当它转到 model.fit() 时,我收到 GPU 内存不足错误。我正在使用张量流 2.2。根据其手册,model.fit_generator 已被弃用,model.fit 是首选。

有人可以概述一下使用 tensorflow v2.2 训练大型数据集的步骤吗?

最佳解决方案是使用 tf.data.Dataset(),因此您可以使用 .batch() 方法轻松地对数据进行批处理。

这里有很多教程,您可能想使用 from_tensor_slices() 直接玩 numpy 数组。

下面有两个优秀的文档可以满足您的需要。

https://www.tensorflow.org/tutorials/load_data/numpy

https://www.tensorflow.org/guide/data