在 Pyspark 的 RDD 分区中拆分数组

Splitting arrays in RDD partitions in Pyspark

我有一个单独的 3D 数值数据文件,我从中读取数据块(因为读取数据块比单个索引更快)。例如说 'file' 中有一个 MxNx30 数组,我会创建一个这样的 RDD:

def read(ind):
    f = customFileOpener(file)
    return f['data'][:,:,ind[0]:ind[-1]+1]

indices = [[0,9],[10,19],[20,29]]
rdd = sc.parallelize(indices,3).map(lambda v:read(v))
rdd.count()

所以这 3 个分区中的每一个都有一个大小为 MxNx10 的 numpy.ndarray 元素。

现在,我想拆分每个元素,因此在每个分区中,我有 10 个元素,每个元素都是一个 MxN 数组。为此,我尝试使用 flatMap(),但出现 'NoneType object is not iterable':

错误
def splitArr(arr):
    Nmid = arr.shape[-1]
    out = []
    for i in range(0,Nmid):
         out.append(arr[...,i])
    return out

rdd2 = rdd.flatMap(lambda v: splitArr(v))
rdd2.count()

正确的做法是什么?关键点是 (a) 我需要从文件中分块读取数据,以及 (b) 拆分数据,使元素大小为 MxN(最好保持分区结构)。

据我了解你的描述,这样的事情应该可以解决问题:

rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2)))

或者如果您更喜欢单独的功能:

def splitArr(arr):
    for x in np.rollaxis(arr, 2):
        yield x

rdd.flatMap(splitArr)