Setting up a LearningRateScheduler in Keras

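I'm setting up a learning rate scheduler in Keras that uses the loss history to update self.model.optimizer.lr, but the value assigned to self.model.optimizer.lr is not picked up by the SGD optimizer, which keeps using its default learning rate. The code is: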

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn.preprocessing import StandardScaler

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
        self.model.optimizer.lr=3
    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        self.model.optimizer.lr=lr-10000*self.losses[-1]

def base_model():
    model=Sequential()
    model.add(Dense(4, input_dim=2, init='uniform'))
    model.add(Dense(1, init='uniform'))
    sgd = SGD(decay=2e-5, momentum=0.9, nesterov=True)
    model.compile(loss='mean_squared_error', optimizer=sgd, metrics=['mean_absolute_error'])
    return model

history=LossHistory()

estimator = KerasRegressor(build_fn=base_model,nb_epoch=10,batch_size=16,verbose=2,callbacks=[history])

estimator.fit(X_train,y_train,callbacks=[history])

res = estimator.predict(X_test)

Using Keras as a regressor for a continuous variable works fine, but I want to get smaller derivatives by updating the optimizer's learning rate.

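The learning rate is a variable that lives on the computing device (for example, on the GPU if you compute on a GPU), so assigning a plain Python number to self.model.optimizer.lr merely replaces the attribute, and the already-built training function keeps using the original variable. This means you have to use K.set_value, where K is keras.backend. For example: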

import keras.backend as K
K.set_value(opt.lr, 0.01)

or, in your example:

K.set_value(self.model.optimizer.lr, K.get_value(self.model.optimizer.lr) - 10000*self.losses[-1])
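A toy model in plain Python may make the difference concrete. It is not Keras code: `Variable`, `set_value`, `Optimizer` and `make_train_step` are hypothetical stand-ins, assuming (as in graph-mode Keras) that the training step captures the learning-rate variable once, when it is built. Rebinding the attribute then has no effect on the step, while updating the captured variable in place, the analogue of `K.set_value`, does.

```python
class Variable:
    """Hypothetical stand-in for a backend variable: a mutable box holding a value."""
    def __init__(self, value):
        self.value = value


def set_value(var, value):
    """Hypothetical analogue of K.set_value: update the box in place."""
    var.value = value


class Optimizer:
    """Hypothetical optimizer that stores its learning rate in a Variable."""
    def __init__(self, lr):
        self.lr = Variable(lr)


def make_train_step(optimizer):
    # The step captures the Variable object once, roughly like a compiled training function.
    lr_var = optimizer.lr

    def train_step(weight, grad):
        return weight - lr_var.value * grad

    return train_step


# Case 1: rebinding the attribute (what `self.model.optimizer.lr = 3` does)
opt_a = Optimizer(lr=0.01)
step_a = make_train_step(opt_a)
opt_a.lr = 3                      # the step still reads the original Variable
w_rebound = step_a(1.0, 1.0)      # uses lr = 0.01

# Case 2: updating the captured variable in place (what K.set_value does)
opt_b = Optimizer(lr=0.01)
step_b = make_train_step(opt_b)
set_value(opt_b.lr, 0.5)
w_set = step_b(1.0, 1.0)          # uses lr = 0.5

print(w_rebound, w_set)
```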

Thanks, I found an alternative solution, since I'm not using a GPU:

import keras
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
from keras.callbacks import LearningRateScheduler

sd=[]
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = [1,1]

    def on_epoch_end(self, epoch, logs={}):
        self.losses.append(logs.get('loss'))
        sd.append(step_decay(len(self.losses)))
        print('lr:', step_decay(len(self.losses)))

epochs = 50
learning_rate = 0.1
decay_rate = 5e-6
momentum = 0.9

model=Sequential()
model.add(Dense(4, input_dim=2, init='uniform'))
model.add(Dense(1, init='uniform'))
sgd = SGD(lr=learning_rate,momentum=momentum, decay=decay_rate, nesterov=False)
model.compile(loss='mean_squared_error',optimizer=sgd,metrics=['mean_absolute_error'])

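def step_decay(losses):
    # The argument (epoch index from LearningRateScheduler, list length from LossHistory)
    # is ignored; the decision uses the latest loss stored in the global `history`
    if float(2*np.sqrt(np.array(history.losses[-1])))<0.3:
        lrate=0.01*1/(1+0.1*len(history.losses))
        # these local assignments do not change the already-compiled SGD optimizer
        momentum=0.8
        decay_rate=2e-6
        return lrate
    else:
        lrate=0.1
        return lrate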
history=LossHistory()
lrate=LearningRateScheduler(step_decay)

model.fit(X_train,y_train,nb_epoch=epochs,callbacks=[history,lrate],verbose=2)
model.predict(X_test)

The output is as follows (lr is the learning rate):

Epoch 41/50
lr: 0.0018867924528301887
0s - loss: 0.0126 - mean_absolute_error: 0.0785
Epoch 42/50
lr: 0.0018518518518518517
0s - loss: 0.0125 - mean_absolute_error: 0.0780
Epoch 43/50
lr: 0.0018181818181818182
0s - loss: 0.0125 - mean_absolute_error: 0.0775
Epoch 44/50
lr: 0.0017857142857142857
0s - loss: 0.0126 - mean_absolute_error: 0.0785
Epoch 45/50
lr: 0.0017543859649122807
0s - loss: 0.0126 - mean_absolute_error: 0.0773

That is how the learning rate changes from epoch to epoch.
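The logged values follow directly from the formula in step_decay. Below is a sketch that recomputes them, assuming the last loss stays below 0.0225 (so that 2*sqrt(loss) < 0.3), as it does in the log. `step_decay_from_losses` is a hypothetical refactor that takes the loss list as an argument instead of reading the global `history`, and it uses `math.sqrt` in place of `np.sqrt`.

```python
import math


def step_decay_from_losses(losses):
    """Hypothetical pure version of step_decay: the loss list is passed in explicitly."""
    if 2 * math.sqrt(losses[-1]) < 0.3:  # equivalent to loss < 0.0225
        return 0.01 * 1 / (1 + 0.1 * len(losses))
    return 0.1


# LossHistory starts with [1, 1] and appends one loss per epoch, so after
# epoch k (1-based) the list has k + 2 entries. Only the length and the last
# element matter, so a constant small loss is enough to reproduce the log.
for epoch in range(41, 46):
    losses = [1, 1] + [0.0126] * epoch
    print(epoch, step_decay_from_losses(losses))
```

After epoch k the printed rate is therefore 0.01 / (1 + 0.1 * (k + 2)); for k = 41 that is 0.01 / 5.3 ≈ 0.0018868, the first lr value in the log.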

keras.callbacks.LearningRateScheduler(schedule, verbose=0)

In the newer Keras API you can use a more general version of the schedule function that takes two arguments, epoch and lr.

From the docs:

schedule: a function that takes an epoch index as input (integer, indexed from 0) and current learning rate and returns a new learning rate as output (float).

From the source:

    try:  # new API
        lr = self.schedule(epoch, lr)
    except TypeError:  # old API for backward compatibility
        lr = self.schedule(epoch)
    if not isinstance(lr, (float, np.float32, np.float64)):
        raise ValueError('The output of the "schedule" function '
                         'should be float.')
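The try/except dispatch in that excerpt can be mimicked in plain Python to see which kinds of function it accepts. `call_schedule`, `old_style` and `new_style` are hypothetical names, not part of Keras, and for simplicity the sketch only accepts a Python float (the real check also allows np.float32 and np.float64).

```python
def call_schedule(schedule, epoch, lr):
    """Hypothetical helper mirroring the excerpt: try schedule(epoch, lr),
    then fall back to schedule(epoch) for old one-argument functions."""
    try:  # new API
        new_lr = schedule(epoch, lr)
    except TypeError:  # old API for backward compatibility
        new_lr = schedule(epoch)
    if not isinstance(new_lr, float):
        raise ValueError('The output of the "schedule" function should be float.')
    return new_lr


def old_style(epoch):
    # one-argument schedule: depends only on the epoch index
    return 0.1 / (1 + epoch)


def new_style(epoch, lr):
    # two-argument schedule: halves the current learning rate
    return lr * 0.5


print(call_schedule(old_style, 3, 0.01))
print(call_schedule(new_style, 3, 0.01))
```

One side effect of this design: a TypeError raised inside a two-argument schedule is also caught, and the function is then retried with a single argument, which typically fails with a second, less informative TypeError.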

So your function could be:

def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 90
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr

callbacks = [
    keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=1)
]

model.fit(callbacks=callbacks, ... )
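To check when this schedule actually decays, it can be replayed outside of Keras. The `simulate` driver below is hypothetical, assuming, as the docs quoted above describe, that the schedule receives a 0-indexed epoch and the current learning rate, and that its return value becomes the rate for that epoch.

```python
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 90
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr


def simulate(schedule, initial_lr, epochs):
    """Hypothetical driver: call the schedule at the start of each 0-indexed
    epoch, feeding the previous result back in as the current lr."""
    lr = initial_lr
    history = []
    for epoch in range(epochs):
        lr = schedule(epoch, lr)
        history.append(lr)
    return history


lrs = simulate(lr_scheduler, 0.1, 200)
print(lrs[0], lrs[89], lrs[90], lrs[179], lrs[180])
```

The `and epoch` guard skips epoch 0 (where `0 % 90 == 0`), so the rate is multiplied by 0.1 only at epochs 90, 180, and so on.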