Tensorflow 数据 API:重复()

Tensorflow Data API: repeat()

以下代码摘自“使用 scimitar-learn、Keras 和 tensorflow 进行机器学习实践”。 我理解以下代码中的所有内容,除了在第二行链接 .repeat(repeat) 函数。

我知道 repeat 是重复数据集元素(即,在本例中为文件路径),如果参数设置为 None 或留空,重复将永远持续下去,直到使用它的函数决定何时停止。

正如您在下面的代码中看到的,作者将 repeat() 参数设置为 None;

1 - 基本上我想知道作者为什么决定这样做?

2 - 或者是因为代码试图模拟数据集不适合内存的情况,如果是这种情况那么在真实情况下我们应该避免 repeat(),我是对的吗?


def csv_reader_dataset(filepaths, repeat=1, n_readers=5,
                       n_read_threads=None, shuffle_buffer_size=10000,
                       n_parse_threads=5, batch_size=32):
    dataset = tf.data.Dataset.list_files(filepaths, seed = 42).repeat(repeat)
    dataset = dataset.interleave(
        lambda filepath: tf.data.TextLineDataset(filepath).skip(1),
        cycle_length = n_readers, num_parallel_calls = n_read_threads)
    
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(preprocess, num_parallel_calls = n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(1)

train_set = csv_reader_dataset(train_filepaths, repeat = None)
valid_set = csv_reader_dataset(valid_filepaths)
test_set = csv_reader_dataset(test_filepaths)


keras.backend.clear_session()
np.random.seed(42)
tf.random.set_seed(42)

model = keras.models.Sequential([
    keras.layers.InputLayer(input_shape = X_train.shape[-1: ]),
    keras.layers.Dense(30, activation = 'relu'),
    keras.layers.Dense(1)
])

m_loss = keras.losses.mean_squared_error
m_optimizer = keras.optimizers.SGD(lr = 1e-3)

batch_size = 32
model.compile(loss = m_loss, optimizer = m_optimizer, metrics = ['accuracy'])
model.fit(train_set, steps_per_epoch = len(X_train) // batch_size, epochs = 10, validation_data = valid_set)

对于你的问题,我认为:

  • tf.data API 不会轻易导致内存不足,因为它加载给定文件路径或 tfrecrods(压缩模式)的数据。因此,repeat() 与内存无关;相反,它用于数据转换。
  • 设置steps_per_epoch为#时我必须使用repeat(#)。假设你的 batch_num = 32,并且 steps_per_epoch = 100//32 = 3 -> 每个时期需要 3 * 32 = 96 个样本 但是你的数据只有 80 个样本。然后,我必须使用 data.repeat(2) 总共有 160 个样本,其中 repeat_1 中的 80 个样本和 repeat_2 中的前 16 个样本将在 1 个时期内使用。这是为了防止错误 Input 运行 out of data.

我在书的作者 git 回购中有另一个相同问题的副本。 问题已澄清;这是由于 Keras 2.0 中的一个错误。

阅读更多内容:https://github.com/ageron/handson-ml2/issues/407