了解 scikit-learn GridSearchCV - 参数调整和平均性能指标

Understanding scikit-learn GridSearchCV - param tuning and averaging performance metrics

我想了解 scikit-learn 中的 GridSearchCV 究竟是如何在机器学习中实现训练验证测试原则的。正如您在下面的代码中看到的,我理解它的作用如下:

  1. 将 'dataset' 分成 75% 和 25%,其中 75% 用于参数调优,25%​​ 用于保留测试集(第 1 行)
  2. 初始化一些参数进行搜索(第 3 到 6 行)
  3. 在 75% 的数据集上拟合模型,但将此数据集分成 5 份,即每次在 60% 的数据上训练,在另外 15% 的数据上测试,重复 5 次(第 8 行) - 10). 这里有我的第一个和第二个问题,见下文。
  4. 采用性能最好的模型和参数,测试 holdout 数据(第 11-13 行)

问题 1:关于参数 space,步骤 3 中到底发生了什么? GridSearchCV 是否在五个 运行s(5 倍)中的每一个上尝试每个参数组合,因此总共给出 10 运行s? (即,'optmizers'、'init' 和 'batches' 中的单个参数与 'epoches']

中的 2 个配对

问题 2:第 'cross_val_score' 行打印的分数是多少?这是 5 个 运行 中每个数据的单次折叠上面 10 个 运行 的平均值吗? (即整个数据集的五个 15% 的平均值)?

问题3:假设第5行现在只有1个参数值,这次GridSearchCV真的没有搜索任何参数,因为每个参数只有1个值,这样正确吗?

问题 4:在问题 3 中解释的情况下,如果我们对 GridSearchCV 运行s 的 5 倍和heldout 运行,这给了我们整个数据集的平均性能分数——这与 6 折交叉验证实验(即没有网格搜索)非常相似,除了 6 折的大小不完全相等。或者这不是?

非常感谢您的任何回复!

X_train_data, X_test_data, y_train, y_test = \
         train_test_split(dataset[:,0:8], dataset[:,8],
                          test_size=0.25,
                          random_state=42) #line 1

model = KerasClassifier(build_fn=create_model, verbose=0)
optimizers = ['adam']  #line 3
init = ['uniform']
epochs = [10,20] #line 5
batches = [5]   # line 6
param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, init=init)
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=5)  # line 8
grid_result = grid.fit(X_train_data, y_train) 
cross_val_score(grid.best_estimator_, X_train_data, y_train, cv=5).mean() #line 10
best_param_ann = grid.best_params_      #line 11
best_estimator = grid.best_estimator_
heldout_predictions = best_estimator.predict(X_test_data)   #line 13

问题 1: 正如您所说,您的数据集将分为 5 部分。 将尝试每个参数(在您的情况下为 2)。对于每个参数,模型将在 5 折中的 4 折上进行训练。剩下的一个将用作测试。所以你是对的,在你的例子中,你将训练一个模型 10 次。

问题二: 'cross_val_score' 是 5 次测试的平均值(准确率、损失或其他)。这样做是为了避免仅仅因为测试集真的很简单就得到了好的结果。

问题三: 是的。如果您只有一组参数来尝试进行网格搜索,那将毫无意义

问题四: 我没有完全理解你的问题。通常,您在训练集上使用网格搜索。这允许您将测试集保留为验证集。如果没有交叉验证,您可能会找到一个完美的设置来最大化您的测试集的结果,并且您会过度拟合您的测试集。通过交叉验证,您可以使用 fine-tuning 参数随心所欲地玩,因为您不使用验证集来设置它。

在你的代码中,没有太多的 CV 需求,因为你没有太多参数可以使用,但如果你开始添加正则化,你可以尝试 10+,在这种情况下,CV 是必需的.

希望对你有帮助,