在 numpy 数组上并行化区域计算

parallelize zonal computation on numpy array

我尝试计算 numpy 数组上同一区域(相同值)的所有单元格的模式。我给你下面的代码示例。在这个例子中,顺序方法工作正常,但多进程方法什么都不做。我没有发现我的错误。

有人看到我的错误了吗?

我想并行计算,因为我的实际数组是一个 10k * 10k 数组,有 1M 个区域。

import numpy as np
import scipy.stats as ss
import multiprocessing as mp

def zone_mode(i, a, b, output):
    to_extract = np.where(a == i)
    val = b[to_extract]
    output[to_extract] = ss.mode(val)[0][0]
    return output

def zone_mode0(i, a, b):
    to_extract = np.where(a == i)
    val = b[to_extract]
    output = ss.mode(val)[0][0]
    return output

np.random.seed(1)

zone = np.array([[1, 1, 1, 2, 3],
                 [1, 1, 2, 2, 3],
                 [4, 2, 2, 3, 3],
                 [4, 4, 5, 5, 3],
                 [4, 6, 6, 5, 5],
                 [6, 6, 6, 5, 5]])
values = np.random.randint(8, size=zone.shape)

output = np.zeros_like(zone).astype(np.float)

for i in np.unique(zone):
    output = zone_mode(i, zone, values, output)

# for multiprocessing    
zone0 = zone - 1

pool = mp.Pool(mp.cpu_count() - 1)
results = [pool.apply(zone_mode0, args=(u, zone0, values)) for u in np.unique(zone0)]
pool.close()
output = results[zone0]

对于数组中的正整数 - zonevalues,我们可以使用 np.bincount。基本思想是我们将 zonevalues 视为二维网格上的行和列。因此,可以将它们映射到它们的线性索引等效数。这些将用作 np.bincount 的分箱求和的分箱。他们的 argmax ID 将是模式编号。它们被映射回 zone-grid 并索引到 zone.

因此,解决方案是 -

m = zone.max()+1
n = values.max()+1
ids = zone*n + values
c = np.bincount(ids.ravel(),minlength=m*n).reshape(-1,n).argmax(1)
out = c[zone]

对于稀疏数据(在输入数组中很好地分布整数),我们可以查看稀疏矩阵以获得 argmax ID c。因此,使用 SciPy 的稀疏矩阵 -

from scipy.sparse import coo_matrix

data = np.ones(zone.size,dtype=int)
r,c = zone.ravel(),values.ravel()
c = coo_matrix((data,(r,c))).argmax(1).A1

轻微性能。提升,指定形状 -

c = coo_matrix((data,(r,c)),shape=(m,n)).argmax(1).A1

求解泛型 values

我们将利用pandas.factorize,像这样-

import pandas as pd

ids,unq = pd.factorize(values.flat)
v = ids.reshape(values.shape)
# .. same steps as earlier with bincount, using v in place of values
out = unq[c[zone]]

请注意,对于并列案例,它会选择随机元素 values。如果要选择第一个,请使用 pd.factorize(values.flat, sort=True).