对 tf.estimator.Estimator 中的参数感到困惑

Confusion about the parameter in tf.estimator.Estimator

所以我正在使用

tf.estimator.Estimator(
    model_fn, model_dir=None, config=None, params=None, warm_start_from=None
)

我对参数 params 感到困惑。

我知道它是 dict,根据一些示例代码,我假设 params 有点像:

params = {"batch_size":128,
          "hidden_layer": 3
}

但是根据官方页面,params 是将传递给 model_fn 的超参数字典。键是参数的名称,值是基本 python 类型 (offical page)。所以值应该是 python 类型像 int64, float64?

请给我一个明确的解释。非常感谢您的帮助

再往下docs

The params argument contains hyperparameters. It is passed to the model_fn, if the model_fn has a parameter named "params", and to the input functions in the same manner. Estimator only passes params along, it does not inspect it. The structure of params is therefore entirely up to the developer.

换句话说,认为合适的才是合适的。如果您的模型加载权重,它可能是权重文件的字符串路径:weights_path = "model.h5"。辍学率浮动,介于 0.1. 之间。像这样:

def model_fn(params):
    ...
    x = Dense(params['units'])(x)
    x = Dropout(params['dropout'])(x)
    ...
    model.load_weights(params['weights_path'])
    return model

TF 检查 model_fn 是否有 params 参数 here,并相应地传递它。 model_fn 也可以有任何其他参数。