sklearn 与从头实现的 K-means 聚类结果不同
Different K-means results from sklearn and from a from-scratch implementation
我尝试比较 sklearn 包与从头实现的 K-means 聚类结果。从头实现的代码如下:
import matplotlib.pyplot as plt
from matplotlib import style
# Apply matplotlib's "ggplot" style sheet to every plot drawn below.
style.use('ggplot')
import numpy as np
# Plot colors indexed by cluster number; the 5-color cycle is repeated 10x,
# so cluster indices 0..49 all map to a color.
colors = 10 * ["g", "r", "c", "b", "k"]
class K_Means:
    """Plain Lloyd's-algorithm k-means clustering.

    Centroids are seeded with the first ``k`` rows of the training data. That is
    deterministic, but it differs from scikit-learn's default (k-means++
    seeding), so the two implementations can legitimately converge to
    different local optima on the same data.

    Attributes set by ``fit``:
        centroids: dict mapping cluster index ``0..k-1`` to its centroid array.
        classifications: dict mapping cluster index to the list of data rows
            assigned to that cluster in the last iteration run.
    """

    def __init__(self, k=3, tol=0.001, max_iter=300, *, verbose=False):
        """Configure the clusterer.

        Args:
            k: number of clusters.
            tol: convergence threshold in percent. Fitting stops once, in a
                single iteration, every centroid moved by at most ``tol``
                percent of its previous Euclidean norm.
            max_iter: upper bound on the number of assign/update rounds.
            verbose: when True, print the percent shift of each centroid that
                is still above ``tol`` (debug output).
        """
        self.k = k
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose

    def _shift_percent(self, old, new):
        """Return how far a centroid moved, as a percent of its previous norm."""
        shift = np.linalg.norm(new - old)
        scale = np.linalg.norm(old)
        if scale == 0.0:
            # A centroid at the origin has no relative scale; treat any
            # movement as "not converged" instead of dividing by zero.
            return np.inf if shift > 0.0 else 0.0
        return shift / scale * 100.0

    def fit(self, data):
        """Cluster ``data`` in place, populating ``centroids`` and ``classifications``.

        Args:
            data: array-like of samples, one sample per row (assumed 2-D,
                ``(n_samples, n_features)``). It must contain at least ``k``
                rows, because the first ``k`` rows seed the centroids.
        """
        data = np.asarray(data)
        self.centroids = {i: data[i] for i in range(self.k)}
        # Defined up front so the attribute exists even when max_iter == 0.
        self.classifications = {i: [] for i in range(self.k)}

        for _ in range(self.max_iter):
            # Assignment step: every sample goes to its nearest centroid.
            self.classifications = {i: [] for i in range(self.k)}
            for featureset in data:
                self.classifications[self.predict(featureset)].append(featureset)

            # Update step: move each centroid to the mean of its members.
            prev_centroids = dict(self.centroids)
            for cluster, members in self.classifications.items():
                # An empty cluster has no mean (np.average would give NaN and
                # poison every later distance), so keep its previous centroid.
                if members:
                    self.centroids[cluster] = np.average(members, axis=0)

            # Convergence: use the magnitude of each centroid's movement. The
            # signed sum used previously let centroids that moved toward the
            # origin count as "converged" and stopped the loop early.
            optimized = True
            for cluster in self.centroids:
                moved = self._shift_percent(prev_centroids[cluster], self.centroids[cluster])
                if moved > self.tol:
                    if self.verbose:
                        print(moved)
                    optimized = False
            if optimized:
                break

    def predict(self, data):
        """Return the index of the centroid nearest to the single sample ``data``.

        Distances are Euclidean. Ties go to the lowest cluster index.
        """
        distances = [np.linalg.norm(data - self.centroids[c]) for c in self.centroids]
        return int(np.argmin(distances))
# NOTE(review): reduced_data is defined outside this snippet (presumably a
# 2-D projection of the dataset); the plotting below uses columns 0 and 1.
kmeans = K_Means()
kmeans.fit(reduced_data)

# Centroids: large blue crosses drawn above the points (zorder=10).
for centroid in kmeans.centroids.values():
    plt.scatter(centroid[0], centroid[1],
                marker="x", color="b", s=169, linewidths=3, zorder=10)

# One scatter call per cluster rather than one per point: far fewer artists
# and much faster drawing on large datasets. Empty clusters have nothing to draw.
for classification, members in kmeans.classifications.items():
    if not members:
        continue
    points = np.asarray(members)
    plt.scatter(points[:, 0], points[:, 1], marker="o", color=colors[classification])

plt.show()
但是,两者收敛到的质心不同,因此聚类结果也不同。
来自 sklearn 的散点图:
同时,上面代码的散点图:
我想知道从头实现的代码中有什么错误。
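K-means 的结果高度依赖于初始化条件,即算法的起点。scikit-learn 的 KMeans 默认使用 k-means++ 根据数据智能地选取初始质心,而上面从头实现的代码直接取前 k 个样本作为初始质心,因此两者可能收敛到不同的局部最优解。如果您仔细阅读文档,就可以配置 scikit-learn 使其与您自己的实现一致。例如,传入 `init=reduced_data[:3]` 和 `n_init=1`,即可从相同的初始质心开始。另外,可以查看 scikit-learn 的源代码获取更多线索。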
我尝试比较 sklearn 包与从头实现的 K-means 聚类结果。从头实现的代码如下:
import matplotlib.pyplot as plt
from matplotlib import style
# Apply matplotlib's "ggplot" style sheet to every plot drawn below.
style.use('ggplot')
import numpy as np
# Plot colors indexed by cluster number; the 5-color cycle is repeated 10x,
# so cluster indices 0..49 all map to a color.
colors = 10 * ["g", "r", "c", "b", "k"]
class K_Means:
    """Plain Lloyd's-algorithm k-means clustering.

    Centroids are seeded with the first ``k`` rows of the training data. That is
    deterministic, but it differs from scikit-learn's default (k-means++
    seeding), so the two implementations can legitimately converge to
    different local optima on the same data.

    Attributes set by ``fit``:
        centroids: dict mapping cluster index ``0..k-1`` to its centroid array.
        classifications: dict mapping cluster index to the list of data rows
            assigned to that cluster in the last iteration run.
    """

    def __init__(self, k=3, tol=0.001, max_iter=300, *, verbose=False):
        """Configure the clusterer.

        Args:
            k: number of clusters.
            tol: convergence threshold in percent. Fitting stops once, in a
                single iteration, every centroid moved by at most ``tol``
                percent of its previous Euclidean norm.
            max_iter: upper bound on the number of assign/update rounds.
            verbose: when True, print the percent shift of each centroid that
                is still above ``tol`` (debug output).
        """
        self.k = k
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose

    def _shift_percent(self, old, new):
        """Return how far a centroid moved, as a percent of its previous norm."""
        shift = np.linalg.norm(new - old)
        scale = np.linalg.norm(old)
        if scale == 0.0:
            # A centroid at the origin has no relative scale; treat any
            # movement as "not converged" instead of dividing by zero.
            return np.inf if shift > 0.0 else 0.0
        return shift / scale * 100.0

    def fit(self, data):
        """Cluster ``data`` in place, populating ``centroids`` and ``classifications``.

        Args:
            data: array-like of samples, one sample per row (assumed 2-D,
                ``(n_samples, n_features)``). It must contain at least ``k``
                rows, because the first ``k`` rows seed the centroids.
        """
        data = np.asarray(data)
        self.centroids = {i: data[i] for i in range(self.k)}
        # Defined up front so the attribute exists even when max_iter == 0.
        self.classifications = {i: [] for i in range(self.k)}

        for _ in range(self.max_iter):
            # Assignment step: every sample goes to its nearest centroid.
            self.classifications = {i: [] for i in range(self.k)}
            for featureset in data:
                self.classifications[self.predict(featureset)].append(featureset)

            # Update step: move each centroid to the mean of its members.
            prev_centroids = dict(self.centroids)
            for cluster, members in self.classifications.items():
                # An empty cluster has no mean (np.average would give NaN and
                # poison every later distance), so keep its previous centroid.
                if members:
                    self.centroids[cluster] = np.average(members, axis=0)

            # Convergence: use the magnitude of each centroid's movement. The
            # signed sum used previously let centroids that moved toward the
            # origin count as "converged" and stopped the loop early.
            optimized = True
            for cluster in self.centroids:
                moved = self._shift_percent(prev_centroids[cluster], self.centroids[cluster])
                if moved > self.tol:
                    if self.verbose:
                        print(moved)
                    optimized = False
            if optimized:
                break

    def predict(self, data):
        """Return the index of the centroid nearest to the single sample ``data``.

        Distances are Euclidean. Ties go to the lowest cluster index.
        """
        distances = [np.linalg.norm(data - self.centroids[c]) for c in self.centroids]
        return int(np.argmin(distances))
# NOTE(review): reduced_data is defined outside this snippet (presumably a
# 2-D projection of the dataset); the plotting below uses columns 0 and 1.
kmeans = K_Means()
kmeans.fit(reduced_data)

# Centroids: large blue crosses drawn above the points (zorder=10).
for centroid in kmeans.centroids.values():
    plt.scatter(centroid[0], centroid[1],
                marker="x", color="b", s=169, linewidths=3, zorder=10)

# One scatter call per cluster rather than one per point: far fewer artists
# and much faster drawing on large datasets. Empty clusters have nothing to draw.
for classification, members in kmeans.classifications.items():
    if not members:
        continue
    points = np.asarray(members)
    plt.scatter(points[:, 0], points[:, 1], marker="o", color=colors[classification])

plt.show()
但是,两者收敛到的质心不同,因此聚类结果也不同。
来自 sklearn 的散点图:
同时,上面代码的散点图:
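K-means 的结果高度依赖于初始化条件,即算法的起点。scikit-learn 的 KMeans 默认使用 k-means++ 根据数据智能地选取初始质心,而上面从头实现的代码直接取前 k 个样本作为初始质心,因此两者可能收敛到不同的局部最优解。如果您仔细阅读文档,就可以配置 scikit-learn 使其与您自己的实现一致。例如,传入 `init=reduced_data[:3]` 和 `n_init=1`,即可从相同的初始质心开始。另外,可以查看 scikit-learn 的源代码获取更多线索。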