提高 kNN 分类器的性能(速度)
Improving performance (speed) of kNN classifier
作为一项作业,我必须创建自己的 kNN 分类器,而不使用 for 循环。我设法使用 scipy.spatial.KDTree
找到测试集中每个向量的最近邻居,然后我使用 scipy.stats.mode
到 return 预测列表 类.但是,当集合的大小非常大时,这会花费很长时间。例如,我创建了以下受 this page
启发的示例
import numpy as np
from sklearn.model_selection import train_test_split
from scipy import spatial
from scipy.stats import mode
def predict(X_test):
X = Y_train[tree.query(X_test, k=k)[1]]
Y = mode(X, axis=-1)[0].T[0]
return Y
def load_data():
x1 = 1.5 * np.random.randn(100) + 1
y1 = 1.5 * np.random.randn(100) + 2
x2 = 1.5 * np.random.randn(100) + 3
y2 = 1.5 * np.random.randn(100) + 4
X = np.vstack((np.hstack((x1,x2)),np.hstack((y1,y2)))).T
y = 1.0*np.hstack((np.zeros(100), np.ones(100)))
return X, y
if __name__ == '__main__':
X, y = load_data()
X_train, X_test, Y_train, Y_test = train_test_split(X, y)
k = 7
Z = predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
由于 X = Y_train[tree.query(X_test, k=k)[1]]
部分,这需要很长时间(40-60 秒!)。有什么方法可以提高 this 具体实现的速度,还是我应该想别的办法呢?例如,sklearn
的实现只需要 0.4 秒,与我的实现相比快得离谱。
不得不阅读你的代码几次,但后来我看到你使用的是 KDTree
而不是 cKDTree
。后者是在 Cython 中实现的(而不是普通的 python 和 numpy)并且应该给你一个不错的加速。
作为一项作业,我必须创建自己的 kNN 分类器,而不使用 for 循环。我设法使用 scipy.spatial.KDTree
找到测试集中每个向量的最近邻居,然后我使用 scipy.stats.mode
到 return 预测列表 类.但是,当集合的大小非常大时,这会花费很长时间。例如,我创建了以下受 this page
import numpy as np
from sklearn.model_selection import train_test_split
from scipy import spatial
from scipy.stats import mode
def predict(X_test):
X = Y_train[tree.query(X_test, k=k)[1]]
Y = mode(X, axis=-1)[0].T[0]
return Y
def load_data():
x1 = 1.5 * np.random.randn(100) + 1
y1 = 1.5 * np.random.randn(100) + 2
x2 = 1.5 * np.random.randn(100) + 3
y2 = 1.5 * np.random.randn(100) + 4
X = np.vstack((np.hstack((x1,x2)),np.hstack((y1,y2)))).T
y = 1.0*np.hstack((np.zeros(100), np.ones(100)))
return X, y
if __name__ == '__main__':
X, y = load_data()
X_train, X_test, Y_train, Y_test = train_test_split(X, y)
k = 7
Z = predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
由于 X = Y_train[tree.query(X_test, k=k)[1]]
部分,这需要很长时间(40-60 秒!)。有什么方法可以提高 this 具体实现的速度,还是我应该想别的办法呢?例如,sklearn
的实现只需要 0.4 秒,与我的实现相比快得离谱。
不得不阅读你的代码几次,但后来我看到你使用的是 KDTree
而不是 cKDTree
。后者是在 Cython 中实现的(而不是普通的 python 和 numpy)并且应该给你一个不错的加速。