Keras 警告:Epoch 包含超过 `samples_per_epoch` 个样本
Keras warning: Epoch comprised more than `samples_per_epoch` samples
我有大约 6200 张训练图像,我想使用 keras.preprocessing.image.ImageDataGenerator
class 的 flow(X, y)
方法按以下方式扩充小数据集:
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow(X_train , y_train)
validation_generator = test_datagen.flow(X_val , y_val)
history = model.fit_generator(
train_generator,
samples_per_epoch=1920,
nb_epoch=10,
verbose=1,
validation_data=validation_generator,
nb_val_samples=800)
其中 X_train
/y_train
包含大约 6000 个训练图像和标签,X_val
/y_val
验证数据和模型是增强的 VGG16 模型。
文档说
flow(X, y): Takes numpy data & label arrays, and generates batches of augmented/normalized data. Yields batches indefinitely, in an infinite loop.
对于具有 10 个时期、每个时期 1920 个样本和 batch_size 32 的训练设置,我得到以下训练轨迹:
1920/1920 [==============================] - 3525s - loss: 3.9101 - val_loss: 0.0269
Epoch 2/10
1920/1920 [==============================] - 3609s - loss: 1.0245 - val_loss: 0.0229
Epoch 3/10
1920/1920 [==============================] - 3201s - loss: 0.7620 - val_loss: 0.0161
Epoch 4/10
1916/1920 [============================>.] - ETA: 4s - loss: 0.5978 C:\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\training.py:1537: UserWarning: Epoch comprised more than `samples_per_epoch` samples, which might affect learning results. Set `samples_per_epoch` correctly to avoid this warning.
warnings.warn('Epoch comprised more than
为什么生成器没有像文档中所说的那样生成无限批次?
所以基本上 KerasImageGenerator
class 实现中存在一个小错误。好的是——除了这个烦人的警告之外,没有发生任何错误。所以澄清一下:
flow
和flow_from_directory
实际上都是在无限循环中产生样本。您可以通过测试以下代码轻松检查(警告 - 它会冻结您的 Python
):
for x, y in train_generator:
x = None
您提到的警告是在fit_generator
方法中提出的。它基本上检查在一个时期内处理的样本数量是否小于或等于 samples_per_epoch
。在您的情况下 - samples_per_epoch
可被 batch_size
整除 - 如果 Keras 的实施是正确的 - 永远不应发出此警告......但是..
.. 是啊,为什么会发出这个警告?这有点棘手。如果您更深入地了解生成器的实现,您会注意到生成器以下列方式获取批次:如果您有 10 个示例和 batch_size = 3
,则:
- 它会先打乱这10个例子的顺序,
- 那么首先需要 3 个打乱的示例,然后是接下来的 3 个,依此类推,
- 在第 3 批之后 - 当只剩下 1 个样本时 - 它将 return 一批..只有一个样本。
不要问我为什么 - 这就是生成器的实现方式。好的是它几乎不影响训练过程。
所以 - 总而言之 - 您可以忽略此警告,也可以使传递给生成器的样本数可以被 batch_size
整除。我知道这很麻烦,我希望它会在下一个版本中得到修复。
我有大约 6200 张训练图像,我想使用 keras.preprocessing.image.ImageDataGenerator
class 的 flow(X, y)
方法按以下方式扩充小数据集:
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow(X_train , y_train)
validation_generator = test_datagen.flow(X_val , y_val)
history = model.fit_generator(
train_generator,
samples_per_epoch=1920,
nb_epoch=10,
verbose=1,
validation_data=validation_generator,
nb_val_samples=800)
其中 X_train
/y_train
包含大约 6000 个训练图像和标签,X_val
/y_val
验证数据和模型是增强的 VGG16 模型。
文档说
flow(X, y): Takes numpy data & label arrays, and generates batches of augmented/normalized data. Yields batches indefinitely, in an infinite loop.
对于具有 10 个时期、每个时期 1920 个样本和 batch_size 32 的训练设置,我得到以下训练轨迹:
1920/1920 [==============================] - 3525s - loss: 3.9101 - val_loss: 0.0269
Epoch 2/10
1920/1920 [==============================] - 3609s - loss: 1.0245 - val_loss: 0.0229
Epoch 3/10
1920/1920 [==============================] - 3201s - loss: 0.7620 - val_loss: 0.0161
Epoch 4/10
1916/1920 [============================>.] - ETA: 4s - loss: 0.5978 C:\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\training.py:1537: UserWarning: Epoch comprised more than `samples_per_epoch` samples, which might affect learning results. Set `samples_per_epoch` correctly to avoid this warning.
warnings.warn('Epoch comprised more than
为什么生成器没有像文档中所说的那样生成无限批次?
所以基本上 KerasImageGenerator
class 实现中存在一个小错误。好的是——除了这个烦人的警告之外,没有发生任何错误。所以澄清一下:
flow
和flow_from_directory
实际上都是在无限循环中产生样本。您可以通过测试以下代码轻松检查(警告 - 它会冻结您的Python
):for x, y in train_generator: x = None
您提到的警告是在
fit_generator
方法中提出的。它基本上检查在一个时期内处理的样本数量是否小于或等于samples_per_epoch
。在您的情况下 -samples_per_epoch
可被batch_size
整除 - 如果 Keras 的实施是正确的 - 永远不应发出此警告......但是.... 是啊,为什么会发出这个警告?这有点棘手。如果您更深入地了解生成器的实现,您会注意到生成器以下列方式获取批次:如果您有 10 个示例和
batch_size = 3
,则:- 它会先打乱这10个例子的顺序,
- 那么首先需要 3 个打乱的示例,然后是接下来的 3 个,依此类推,
- 在第 3 批之后 - 当只剩下 1 个样本时 - 它将 return 一批..只有一个样本。
不要问我为什么 - 这就是生成器的实现方式。好的是它几乎不影响训练过程。
所以 - 总而言之 - 您可以忽略此警告,也可以使传递给生成器的样本数可以被 batch_size
整除。我知道这很麻烦,我希望它会在下一个版本中得到修复。