带有矢量化嵌套矩阵的 numpy knn
numpy knn with vectorizing nested matrices
我正在尝试使用 numpy 及其矢量化实现 KNN 算法。
import numpy as np
A = np.array([[11,12,13], [21,22,23], [31,32,33]])
B = np.array([[41,42,43], [51,52,53], [61,62,63]])
C = np.array([[71,72,73], [81,82,83], [91,92,93]])
X = np.array([A, B, C])
f = lambda a,b: (a-b)**2
np.sqrt(np.vectorize(f)(A, X).reshape(3,9).sum(axis=1))
# array([ 0., 90., 180.])
如果我这样使用它,我会得到所有向量到 A 的距离的 3x1 矩阵。
要获得所有向量之间的所有距离,我需要一个结果矩阵如下:
array([[ 0., 90., 180.],
[ 90., 0., 90.],
[ 180., 90., 0.]])
如何改进操作以获得我想要的结果?
In [843]: np.sqrt(f(X[:,None],X).reshape(3,3,9).sum(axis=-1))
Out[843]:
array([[ 0., 90., 180.],
[ 90., 0., 90.],
[ 180., 90., 0.]])
您不需要 vectorize
。 f
直接使用数组。
但是看着
In [841]: f(X[:,None],X).reshape(3,3,9)
Out[841]:
array([[[ 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 900, 900, 900, 900, 900, 900, 900, 900, 900],
[3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600]],
[[ 900, 900, 900, 900, 900, 900, 900, 900, 900],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 900, 900, 900, 900, 900, 900, 900, 900, 900]],
[[3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600],
[ 900, 900, 900, 900, 900, 900, 900, 900, 900],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=int32)
让我怀疑您正在计算超出需要的值。但我还没有尝试着看大局。
我正在尝试使用 numpy 及其矢量化实现 KNN 算法。
import numpy as np
A = np.array([[11,12,13], [21,22,23], [31,32,33]])
B = np.array([[41,42,43], [51,52,53], [61,62,63]])
C = np.array([[71,72,73], [81,82,83], [91,92,93]])
X = np.array([A, B, C])
f = lambda a,b: (a-b)**2
np.sqrt(np.vectorize(f)(A, X).reshape(3,9).sum(axis=1))
# array([ 0., 90., 180.])
如果我这样使用它,我会得到所有向量到 A 的距离的 3x1 矩阵。
要获得所有向量之间的所有距离,我需要一个结果矩阵如下:
array([[ 0., 90., 180.],
[ 90., 0., 90.],
[ 180., 90., 0.]])
如何改进操作以获得我想要的结果?
In [843]: np.sqrt(f(X[:,None],X).reshape(3,3,9).sum(axis=-1))
Out[843]:
array([[ 0., 90., 180.],
[ 90., 0., 90.],
[ 180., 90., 0.]])
您不需要 vectorize
。 f
直接使用数组。
但是看着
In [841]: f(X[:,None],X).reshape(3,3,9)
Out[841]:
array([[[ 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 900, 900, 900, 900, 900, 900, 900, 900, 900],
[3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600]],
[[ 900, 900, 900, 900, 900, 900, 900, 900, 900],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 900, 900, 900, 900, 900, 900, 900, 900, 900]],
[[3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600, 3600],
[ 900, 900, 900, 900, 900, 900, 900, 900, 900],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=int32)
让我怀疑您正在计算超出需要的值。但我还没有尝试着看大局。