scikit learn 的 kmeans 实现中 predict() 方法的用途是什么?

What is the use of predict() method in kmeans implementation of scikit learn?

谁能解释一下在 scikit learn 的 kmeans 实现中 predict() 方法的用途是什么? official documentation 声明其用途为:

Predict the closest cluster each sample in X belongs to.

但是我也可以通过在 fit_transform() 方法上训练模型来获得输入集 X 的每个样本的聚类 number/label。那么predict()方法有什么用呢?是否应该为看不见的数据指出最近的集群?如果是,那么如果你执行SVD等降维措施,你如何处理一个新的数据点?

这是一个 similar question 但我仍然认为它没有真正帮助。

what is the use of predict() method? Is it supposed to point out closest cluster for the unseen data?

是的,完全正确。

then how do you handle a new data point if you perform dimensionality reduction measure such as SVD?

在将未见数据传递给 .predict() 之前,您对未见数据应用相同的降维方法。这是一个典型的工作流程:

# prerequisites:
#    x_train: training data
#    x_test: "unseen" testing data
#    km: initialized `KMeans()` instance
#    dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)    

# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)  

# ...

# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)

# ...

其实,fit_transformfit_predict等方法都是为了方便。 y = km.fit_predict(x) 等同于 y = km.fit(x).predict(x).

我认为如果我们把拟合部分写成下面这样更容易看出是怎么回事:

# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)

km.fit(x_dr)
y = km.predict(x_dr)

除了对 .fit() 的调用外,模型在拟合过程中同样使用,并且具有看不见的数据。

总结:

  • .fit()的目的是用数据训练模型。
  • .predict().transform() 的目的是将经过训练的模型应用于数据。
  • 如果你想拟合模型并在训练时将其应用于相同的数据,为了方便,有.fit_predict().fit_transform()
  • 链接多个模型(例如降维和聚类)时,在拟合和测试期间以相同的顺序应用它们。