Return SPARK 中另一个 RDD 的最大 N 值的 RDD
Return RDD of largest N values from another RDD in SPARK
我正在尝试根据键值将元组的 RDD 过滤为 return 最大的 N 个元组。我需要 return 格式作为 RDD。
所以 RDD:
[(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')]
过滤最大的 3 个键应该 return RDD:
[(6,'p'), (12,'e'), (49,'y')]
执行 sortByKey()
然后 take(N)
return 的值并不会导致 RDD,所以这不会工作。
我可以 return 所有的键,对它们进行排序,找到第 N 个最大值,然后过滤 RDD 以查找大于该值的键值,但这似乎非常低效。
最好的方法是什么?
和RDD
一个快速但不是特别有效的解决方案是遵循 sortByKey
使用 zipWithIndex
和 filter
:
n = 3
rdd = sc.parallelize([(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')])
rdd.sortByKey().zipWithIndex().filter(lambda xi: xi[1] < n).keys()
如果 n 与 RDD 大小相比相对较小,则更有效的方法是避免完全排序:
import heapq
def key(kv):
return kv[0]
top_per_partition = rdd.mapPartitions(lambda iter: heapq.nlargest(n, iter, key))
top_per_partition.sortByKey().zipWithIndex().filter(lambda xi: xi[1] < n).keys()
如果键比值小得多并且最终输出的顺序无关紧要,那么filter
方法可以很好地工作:
keys = rdd.keys()
identity = lambda x: x
offset = (keys
.mapPartitions(lambda iter: heapq.nlargest(n, iter))
.sortBy(identity)
.zipWithIndex()
.filter(lambda xi: xi[1] < n)
.keys()
.max())
rdd.filter(lambda kv: kv[0] <= offset)
此外,如果出现平局,它也不会保留精确的 n 值。
和DataFrames
你可以 orderBy
和 limit
:
from pyspark.sql.functions import col
rdd.toDF().orderBy(col("_1").desc()).limit(n)
一种更省力的方法,因为您只想将 take(N)
结果转换为新的 RDD。
sc.parallelize(yourSortedRdd.take(Nth))
我正在尝试根据键值将元组的 RDD 过滤为 return 最大的 N 个元组。我需要 return 格式作为 RDD。
所以 RDD:
[(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')]
过滤最大的 3 个键应该 return RDD:
[(6,'p'), (12,'e'), (49,'y')]
执行 sortByKey()
然后 take(N)
return 的值并不会导致 RDD,所以这不会工作。
我可以 return 所有的键,对它们进行排序,找到第 N 个最大值,然后过滤 RDD 以查找大于该值的键值,但这似乎非常低效。
最好的方法是什么?
和RDD
一个快速但不是特别有效的解决方案是遵循 sortByKey
使用 zipWithIndex
和 filter
:
n = 3
rdd = sc.parallelize([(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')])
rdd.sortByKey().zipWithIndex().filter(lambda xi: xi[1] < n).keys()
如果 n 与 RDD 大小相比相对较小,则更有效的方法是避免完全排序:
import heapq
def key(kv):
return kv[0]
top_per_partition = rdd.mapPartitions(lambda iter: heapq.nlargest(n, iter, key))
top_per_partition.sortByKey().zipWithIndex().filter(lambda xi: xi[1] < n).keys()
如果键比值小得多并且最终输出的顺序无关紧要,那么filter
方法可以很好地工作:
keys = rdd.keys()
identity = lambda x: x
offset = (keys
.mapPartitions(lambda iter: heapq.nlargest(n, iter))
.sortBy(identity)
.zipWithIndex()
.filter(lambda xi: xi[1] < n)
.keys()
.max())
rdd.filter(lambda kv: kv[0] <= offset)
此外,如果出现平局,它也不会保留精确的 n 值。
和DataFrames
你可以 orderBy
和 limit
:
from pyspark.sql.functions import col
rdd.toDF().orderBy(col("_1").desc()).limit(n)
一种更省力的方法,因为您只想将 take(N)
结果转换为新的 RDD。
sc.parallelize(yourSortedRdd.take(Nth))