如何让 tf.data.Dataset.map 函数只在第一个 epoch 执行一次?

How to make a tf.data.Dataset.map function execute only once, in the first epoch?

我尝试使用 tf.data.Dataset 对数据集进行一些转换。

我发现这些转换在每个 epoch 中都会重新执行。有没有办法让 map 函数只在第一个 epoch 执行?

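您可以为不同的 epoch 使用不同的数据集，这在自定义训练循环中很容易实现，做法如下。（注意：按这种写法，第一个 epoch 之后的数据完全不经过转换；如果您希望转换只计算一次、之后的 epoch 复用转换后的结果，应在 `map` 之后调用 `.cache()`，例如 `data.map(transformation).cache()`。）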

def transformation(inputs, labels):
    """Identity mapping that logs a marker whenever it is executed.

    Placeholder for a real per-example transformation; the (inputs, labels)
    pair is handed back unchanged.
    """
    tf.print('With transformation!')
    return (inputs, labels)

def no_transformation(inputs, labels):
    """Identity mapping used after the first epoch; only logs a marker.

    The (inputs, labels) pair is returned exactly as received.
    """
    tf.print('No transformation!')
    return (inputs, labels)

# Two pipelines over the same first 4 examples, differing only in the mapping
# applied; each is batched into batches of 4.
data_with_transform = (
    data.take(4)
    .map(transformation)
    .batch(4)
)
data_no_transform = (
    data.take(4)
    .map(no_transformation)
    .batch(4)
)

然后在训练循环中按 epoch 选择数据集:

# Pick the pipeline for this epoch: the transformed one only while epoch < 1,
# the untransformed one afterwards.
ds = data_with_transform if epoch < 1 else data_no_transform

# Feed every batch of the chosen pipeline to the training step.
for X_train, y_train in ds:
    train_step(X_train, y_train)

完整示例:

import tensorflow_datasets as tfds
import tensorflow as tf

# Load the Iris 'train' split as (features, label) pairs together with the
# dataset metadata (used below to size the output layer).
data, info = tfds.load(
    'iris',
    split='train',
    with_info=True,
    as_supervised=True,
)

def transformation(inputs, labels):
    """Identity mapping that logs a marker whenever it is executed.

    Placeholder for a real per-example transformation; the (inputs, labels)
    pair is handed back unchanged.
    """
    tf.print('With transformation!')
    return (inputs, labels)

def no_transformation(inputs, labels):
    """Identity mapping used after the first epoch; only logs a marker.

    The (inputs, labels) pair is returned exactly as received.
    """
    tf.print('No transformation!')
    return (inputs, labels)

# Two pipelines over the same first 4 examples, differing only in the mapping
# applied; each is batched into batches of 4.
data_with_transform = (
    data.take(4)
    .map(transformation)
    .batch(4)
)
data_no_transform = (
    data.take(4)
    .map(no_transformation)
    .batch(4)
)

# Small MLP classifier; the last layer has no activation, so it emits one raw
# logit per class (class count taken from the dataset metadata).
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8, activation='relu'))
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(info.features['label'].num_classes))

# Because the model outputs logits, the loss is told to apply softmax itself.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Running metrics accumulated over training steps.
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    """Run one optimisation step on a batch and update the running metrics.

    Computes the logits and loss under a gradient tape, applies the gradients
    to the model's trainable variables, then feeds the loss and predictions
    into `train_loss` / `train_acc`.
    """
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        batch_loss = loss_object(labels, predictions)

    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    opt.apply_gradients(zip(grads, params))

    train_loss(batch_loss)
    train_acc(labels, predictions)


def main(epochs=5):
    """Run the custom training loop for `epochs` passes over the data.

    Epoch 0 iterates `data_with_transform`; every later epoch iterates
    `data_no_transform`, so `transformation` is only executed during the
    first epoch.

    Args:
        epochs: Number of passes over the training dataset.
    """
    for epoch in range(epochs):
        # Clear the running metrics so each epoch is measured on its own.
        # `reset_state()` is the current Keras API; `reset_states()` was
        # deprecated in TF 2.5 and is not available on Keras 3 metrics.
        train_loss.reset_state()
        train_acc.reset_state()

        # Only the first epoch goes through the transformed pipeline.
        ds = data_with_transform if epoch < 1 else data_no_transform

        for X_train, y_train in ds:
            train_step(X_train, y_train)

# Run the training loop (with the default number of epochs) only when this
# file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
输出:

With transformation!
With transformation!
With transformation!
With transformation!

No transformation!
No transformation!
No transformation!
No transformation!

No transformation!
No transformation!
No transformation!
No transformation!

No transformation!
No transformation!
No transformation!
No transformation!

No transformation!
No transformation!
No transformation!
No transformation!