使用 sample_weights 和 fit_generator()

Using sample_weights with fit_generator()

在自回归连续问题中,当零点过多时,可以将这种情况视为零膨胀问题(即 ZIB)。换句话说,我们不想拟合 f(x),而是拟合 g(x)*f(x),其中 f(x) 是我们想要近似的函数,即 yg(x) 是一个函数,它根据值是零还是非零输出 0 到 1 之间的值。

目前,我有两个模型。一个模型给我 g(x),另一个模型适合 g(x)*f(x)

第一个模型给了我一组权重。这就是我需要你帮助的地方。我可以将 sample_weights 参数与 model.fit() 一起使用。当我处理大量数据时,我需要使用 model.fit_generator()。但是,fit_generator() 没有参数 sample_weights

是否有解决方法可以在 fit_generator() 中使用 sample_weights?否则,我如何适应 g(x)*f(x) 知道我已经为 g(x) 训练了模型?

您可以提供样本权重作为生成器返回的元组的第三个元素。来自 fit_generator 上的 Keras 文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

  • a tuple (inputs, targets)
  • a tuple (inputs, targets, sample_weights).

更新: 这是一个生成器的粗略草图,其中 returns 输入样本和目标以及从模型 g(x) 获得的样本权重:

def gen(args):
    while True:
        for i in range(num_batches):
            # get the i-th batch data
            inputs = ...
            targets = ...
            
            # get the sample weights
            weights = g.predict(inputs)
            
            yield inputs, targets, weights
            
            
model.fit_generator(gen(args), steps_per_epoch=num_batches, ...)