train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() 有什么作用?
What does train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() do?
我正在关注 Tensorflow 的 timeseries/LSTM 教程,并且很难理解这行代码的作用,因为它没有真正解释:
train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
我试图查看不同模块的作用,但我无法理解完整的命令及其对数据集的影响。
这是整个教程:
Click
这是一个基于tensorflow.data
API 的输入管道定义。
分解:
(train_data # some tf.data.Dataset, likely in the form of tuples (x, y)
.cache() # caches the dataset in memory (avoids having to reapply preprocessing transformations to the input)
.shuffle(BUFFER_SIZE) # shuffle the samples to have always a random order of samples fed to the network
.batch(BATCH_SIZE) # batch samples in chunks of size BATCH_SIZE (except the last one, that may be smaller)
.repeat()) # repeat forever, meaning the dataset will keep producing batches and never terminate running out of data.
备注:
- 因为重复是在洗牌之后进行的,批次总是不同的,即使是跨时期也是如此。
- 由于
cache()
,数据集的第二次迭代将从内存中的缓存中加载数据,而不是之前的管道步骤。如果数据预处理很复杂,这可以节省你一些时间(但是,对于大数据集,这可能会占用你大量的内存)
BUFFER_SIZE
是随机缓冲区中的项目数。该函数填充缓冲区,然后从中随机采样。适当的洗牌需要足够大的缓冲区,但它与内存消耗保持平衡。重新洗牌在每个时期自动发生。
注意:这是一个管道定义,因此您要指定哪些操作在管道中,而不是实际上运行它们!这些操作实际上是在您调用 next(iter(dataset))
时发生的,而不是之前。
我正在关注 Tensorflow 的 timeseries/LSTM 教程,并且很难理解这行代码的作用,因为它没有真正解释:
train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
我试图查看不同模块的作用,但我无法理解完整的命令及其对数据集的影响。 这是整个教程: Click
这是一个基于tensorflow.data
API 的输入管道定义。
分解:
(train_data # some tf.data.Dataset, likely in the form of tuples (x, y)
.cache() # caches the dataset in memory (avoids having to reapply preprocessing transformations to the input)
.shuffle(BUFFER_SIZE) # shuffle the samples to have always a random order of samples fed to the network
.batch(BATCH_SIZE) # batch samples in chunks of size BATCH_SIZE (except the last one, that may be smaller)
.repeat()) # repeat forever, meaning the dataset will keep producing batches and never terminate running out of data.
备注:
- 因为重复是在洗牌之后进行的,批次总是不同的,即使是跨时期也是如此。
- 由于
cache()
,数据集的第二次迭代将从内存中的缓存中加载数据,而不是之前的管道步骤。如果数据预处理很复杂,这可以节省你一些时间(但是,对于大数据集,这可能会占用你大量的内存) BUFFER_SIZE
是随机缓冲区中的项目数。该函数填充缓冲区,然后从中随机采样。适当的洗牌需要足够大的缓冲区,但它与内存消耗保持平衡。重新洗牌在每个时期自动发生。
注意:这是一个管道定义,因此您要指定哪些操作在管道中,而不是实际上运行它们!这些操作实际上是在您调用 next(iter(dataset))
时发生的,而不是之前。