不规则地播放视图 numpy

broadcast views irregularly numpy

假设我想要一个大小为 (n,m) 的 numpy 数组,其中 n 非常大,但有很多重复项,即。 0:n1 相同,n1:n2 相同等(与n2%n1!=0,即不规则间隔)。有没有办法在查看整个数组的同时为每个重复项仅存储一组值?

示例:

unique_values = np.array([[1,1,1], [2,2,2] ,[3,3,3]]) #these are the values i want to store in memory
index_mapping = np.array([0,0,1,1,1,2,2]) # a mapping between index of array above, with array below

unique_values_view = np.array([[1,1,1],[1,1,1],[2,2,2],[2,2,2],[2,2,2], [3,3,3],[3,3,3]]) #this is how I want the view to look like for broadcasting reasons

我计划将 array(view) 乘以其他一些大小为 (1,m) 的数组,并取该乘积的点积:

other_array1 = np.arange(unique_values.shape[1]).reshape(1,-1) # (1,m)
other_array2 = 2*np.ones((unique_values.shape[1],1)) # (m,1)
output = np.dot(unique_values_view * other_array1, other_array2).squeeze()

输出是一个长度为 n.

的一维数组

根据你的例子,你可以简单地将索引映射分解到最后:

output2 = np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]

assert (output == output2).all()

您的表达式有两个重要的优化:

  • 最后做索引
  • 先将 other_array1other_array2 相乘,然后再与 unique_values

让我们应用这些优化:

>>> output_pp = (unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]

# check for correctness
>>> (output == output_pp).all()
True

# and compare it to @Yakym Pirozhenko's approach
>>> from timeit import timeit
>>> print("yp:", timeit("np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]", globals=globals()))
yp: 3.9105667411349714
>>> print("pp:", timeit("(unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]", globals=globals()))
pp: 2.2684884609188884

如果我们观察两件事,就很容易发现这些优化:

(1) 如果 A 是一个 mxn-矩阵并且 b 是一个 n- 向量那么

A * b == A @ diag(b)
A.T * b[:, None] == diag(b) @ A.T

(2) 如果 A 是一个 mxn-矩阵并且 I 是一个 k- 整数向量 range(m) 然后

A[I] == onehot(I) @ A

onehot 可以定义为

def onehot(I, m, dtype=int):
    out = np.zeros((I.size, m), dtype=dtype)
    out[np.arange(I.size), I] = 1
    return out

使用这些事实并缩写 uvimoa1oa2 我们可以写成

uv[im] * oa1 @ oa2 == onehot(im) @ uv @ diag(oa1) @ oa2

上述优化现在只是为这些矩阵乘法选择最佳顺序的问题,即

onehot(im) @ (uv @ (diag(oa1) @ oa2))

在此基础上反向使用 (1) 和 (2) 我们从 post.

的开头获得优化表达式