如何确定 on_epoch_end 触发器
How be sure on_epoch_end triggers
我制作了一个简单的 HDF5 reader class 以避免将整个数据集加载到内存中。我使用序列 class 这样做,但我不确定 on_epoch_end() 函数是否会正确触发。
我在里面放了一张印刷品,但它从来没有出现过!所以我认为我的代码有问题:
class HDF5Generator(tf.keras.utils.Sequence):
def __init__(self, hdf5_file, shuffle=True):
print("GENERATED")
self.hdf5 = h5py.File(hdf5_file, 'r')
self.shuffle = shuffle
self.indices = list(range(0, len(self.hdf5["samples"])))
random.Random().shuffle(self.indices)
def __len__(self):
return len(self.hdf5["samples"])
def __getitem__(self, idx):
return self.hdf5["samples"][self.indices[idx]], self.hdf5["labels"][self.indices[idx]]
def on_epoch_end(self):
print("RE-SHUFFLE")
random.Random().shuffle(self.indices)
这里是我的使用方法:
d = tf.data.Dataset.from_generator(HD5FGenerator, args=[dataset], output_signature=(...))
d = d.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()
...
model.fit(d, epochs=epochs)
在控制台中出现纪元计数器、进度条、字符串“GENERATED”但从不出现“RE-SHUFFLE”
我错过了什么?
因为这似乎是一个 TF 错误,我找到了一个解决方法来触发我的生成器 on_epoch_end()
。
class CallbackOnEpochEnd(Callback):
def __init__(self, generator):
super(CallbackOnEpochEnd, self).__init__()
self.generator = generator
def on_epoch_end(self, epoch, logs=None):
self.generator.on_epoch_end()
[...]
generator = HDF5Generator()
d = tf.data.Dataset.from_generator(lambda: generator, output_signature=(tf.TensorSpec(shape=(5,20)), tf.TensorSpec(shape=(1,))))
[...]
on_epoch_end_callback = CallbackOnEpochEnd(generator)
[...]
model.fit(d, epochs=5, callbacks=[on_epoch_end_callback])
有了这个“RE-SHUFFLE”,每个纪元后都会出现在控制台上!
我制作了一个简单的 HDF5 reader class 以避免将整个数据集加载到内存中。我使用序列 class 这样做,但我不确定 on_epoch_end() 函数是否会正确触发。
我在里面放了一张印刷品,但它从来没有出现过!所以我认为我的代码有问题:
class HDF5Generator(tf.keras.utils.Sequence):
def __init__(self, hdf5_file, shuffle=True):
print("GENERATED")
self.hdf5 = h5py.File(hdf5_file, 'r')
self.shuffle = shuffle
self.indices = list(range(0, len(self.hdf5["samples"])))
random.Random().shuffle(self.indices)
def __len__(self):
return len(self.hdf5["samples"])
def __getitem__(self, idx):
return self.hdf5["samples"][self.indices[idx]], self.hdf5["labels"][self.indices[idx]]
def on_epoch_end(self):
print("RE-SHUFFLE")
random.Random().shuffle(self.indices)
这里是我的使用方法:
d = tf.data.Dataset.from_generator(HD5FGenerator, args=[dataset], output_signature=(...))
d = d.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()
...
model.fit(d, epochs=epochs)
在控制台中出现纪元计数器、进度条、字符串“GENERATED”但从不出现“RE-SHUFFLE”
我错过了什么?
因为这似乎是一个 TF 错误,我找到了一个解决方法来触发我的生成器 on_epoch_end()
。
class CallbackOnEpochEnd(Callback):
def __init__(self, generator):
super(CallbackOnEpochEnd, self).__init__()
self.generator = generator
def on_epoch_end(self, epoch, logs=None):
self.generator.on_epoch_end()
[...]
generator = HDF5Generator()
d = tf.data.Dataset.from_generator(lambda: generator, output_signature=(tf.TensorSpec(shape=(5,20)), tf.TensorSpec(shape=(1,))))
[...]
on_epoch_end_callback = CallbackOnEpochEnd(generator)
[...]
model.fit(d, epochs=5, callbacks=[on_epoch_end_callback])
有了这个“RE-SHUFFLE”,每个纪元后都会出现在控制台上!