如何确定 on_epoch_end 触发器

How be sure on_epoch_end triggers

我制作了一个简单的 HDF5 reader class 以避免将整个数据集加载到内存中。我使用序列 class 这样做,但我不确定 on_epoch_end() 函数是否会正确触发。

我在里面放了一张印刷品,但它从来没有出现过!所以我认为我的代码有问题:

class HDF5Generator(tf.keras.utils.Sequence):
    def __init__(self, hdf5_file, shuffle=True):
        print("GENERATED")
        self.hdf5 = h5py.File(hdf5_file, 'r')
        self.shuffle = shuffle
        self.indices = list(range(0, len(self.hdf5["samples"])))
        random.Random().shuffle(self.indices)

    def __len__(self):
        return len(self.hdf5["samples"])

    def __getitem__(self, idx):
        return self.hdf5["samples"][self.indices[idx]], self.hdf5["labels"][self.indices[idx]]

    def on_epoch_end(self):
        print("RE-SHUFFLE")
        random.Random().shuffle(self.indices)

这里是我的使用方法:

d = tf.data.Dataset.from_generator(HD5FGenerator, args=[dataset], output_signature=(...))
d = d.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()
...
model.fit(d, epochs=epochs)

在控制台中出现纪元计数器、进度条、字符串“GENERATED”但从不出现“RE-SHUFFLE”

我错过了什么?

因为这似乎是一个 TF 错误,我找到了一个解决方法来触发我的生成器 on_epoch_end()

class CallbackOnEpochEnd(Callback):
    def __init__(self, generator):
        super(CallbackOnEpochEnd, self).__init__()
        self.generator = generator

    def on_epoch_end(self, epoch, logs=None):
        self.generator.on_epoch_end()

[...]

generator = HDF5Generator()
d = tf.data.Dataset.from_generator(lambda: generator, output_signature=(tf.TensorSpec(shape=(5,20)), tf.TensorSpec(shape=(1,))))

[...]

on_epoch_end_callback = CallbackOnEpochEnd(generator)

[...]

model.fit(d, epochs=5, callbacks=[on_epoch_end_callback])

有了这个“RE-SHUFFLE”,每个纪元后都会出现在控制台上!