带有张量流的 Keras 抛出 ResourceExhaustedError
Keras with tensorflow throws ResourceExhaustedError
出于研究目的,我正在训练一个神经网络,该网络根据纪元的奇偶性以不同方式更新其权重:
1) 如果epoch是偶数,用反向传播改变NN的权重
2) 如果epoch是奇数,只用update_weights_with_custom_function()
更新模型,因此冻结网络。
这是实现此代码的简化部分(注意 epochs=1
):
for epoch in range(nb_epoch):
if epoch % 2 == 0:
model.trainable = True # Unfreeze the model
else:
model.trainable = False # Freeze the model
model.compile(optimizer=optim, loss=gaussian_loss, metrics=['accuracy'])
hist = model.fit(X_train, Y_train,
batch_size=batch_size,
epochs=1,
shuffle=True,
verbose=1,
callbacks=[tbCallBack, csv_epochs, early_stop],
validation_data=(X_val, Y_val))
if epoch % 2 == 1:
update_weights_with_custom_function()
问题: 在几个 epoch 之后,keras 抛出一个 ResourceExhaustedError
但只针对 tensorflow, 而不是 theano。似乎循环 compile()
是在创建模型而不释放它们。
所以,我该怎么办?我知道 K.clear_session()
会释放内存,但它需要保存模型并重新加载它 (see),这给我带来了一些问题,因为 load_model()
在我的情况下无法立即使用。
我也对其他方式持开放态度来实现我想要实现的目标(即根据纪元的奇偶性冻结 NN 模型)。
总结: 带有 tensorflow 后端的 keras 抛出 ResourceExhaustedError
因为我正在循环 compile()
。
正如 Marcin Możejko 指出的那样,使用 eval()
正是我想要实现的目标。
我添加了一个自定义回调(灵感来自 here),它避免了 compile()
的循环
问题现已解决,即使 tensorflow 问题并未直接解决。
出于研究目的,我正在训练一个神经网络,该网络根据纪元的奇偶性以不同方式更新其权重:
1) 如果epoch是偶数,用反向传播改变NN的权重
2) 如果epoch是奇数,只用update_weights_with_custom_function()
更新模型,因此冻结网络。
这是实现此代码的简化部分(注意 epochs=1
):
for epoch in range(nb_epoch):
if epoch % 2 == 0:
model.trainable = True # Unfreeze the model
else:
model.trainable = False # Freeze the model
model.compile(optimizer=optim, loss=gaussian_loss, metrics=['accuracy'])
hist = model.fit(X_train, Y_train,
batch_size=batch_size,
epochs=1,
shuffle=True,
verbose=1,
callbacks=[tbCallBack, csv_epochs, early_stop],
validation_data=(X_val, Y_val))
if epoch % 2 == 1:
update_weights_with_custom_function()
问题: 在几个 epoch 之后,keras 抛出一个 ResourceExhaustedError
但只针对 tensorflow, 而不是 theano。似乎循环 compile()
是在创建模型而不释放它们。
所以,我该怎么办?我知道 K.clear_session()
会释放内存,但它需要保存模型并重新加载它 (see),这给我带来了一些问题,因为 load_model()
在我的情况下无法立即使用。
我也对其他方式持开放态度来实现我想要实现的目标(即根据纪元的奇偶性冻结 NN 模型)。
总结: 带有 tensorflow 后端的 keras 抛出 ResourceExhaustedError
因为我正在循环 compile()
。
正如 Marcin Możejko 指出的那样,使用 eval()
正是我想要实现的目标。
我添加了一个自定义回调(灵感来自 here),它避免了 compile()
问题现已解决,即使 tensorflow 问题并未直接解决。