在 Java 中获取数组的 k 个最小(或最大)元素的最快方法是什么?
What is the fastest way to get k smallest (or largest) elements of array in Java?
我有一个元素数组(在示例中,这些只是整数),它们使用一些自定义比较器进行比较。在这个例子中,我通过定义 i SMALLER j
当且仅当 scores[i] <= scores[j]
.
来模拟这个比较器
我有两种方法:
- 使用当前k个候选人的堆
- 使用当前 k 个候选人的数组
我按照以下方式更新上面两个结构:
- 堆:方法
PriorityQueue.poll
和PriorityQueue.offer
,
- 数组:存储候选数组中前k个候选中最差的索引
top
。如果新看到的示例比索引 top
处的元素更好,则后者被前者替换,并且 top
通过遍历数组的所有 k 个元素来更新。
但是,当我测试过哪种方法更快时,我发现这是第二种。问题是:
- 我对
PriorityQueue
的使用不是最理想的吗?
- 计算 k 个最小元素的最快方法是什么?
我对这种情况很感兴趣,当示例数量可以很大,但邻居数量相对较少(在 10 到 20 之间)时。
代码如下:
public static void main(String[] args) {
long kopica, navadno, sortiranje;
int numTries = 10000;
int numExamples = 1000;
int numNeighbours = 10;
navadno = testSimple(numExamples, numNeighbours, numTries);
kopica = testHeap(numExamples, numNeighbours, numTries);
sortiranje = testSort(numExamples, numNeighbours, numTries, false);
System.out.println(String.format("tries: %d examples: %d neighbours: %d\n time heap[ms]: %d\n time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}
public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
Random rnd = new Random(123);
long startTime = System.currentTimeMillis();
for(int iteration = 0; iteration < numberTries; iteration++){
final double[] scores = new double[numberExamples];
for(int i = 0; i < numberExamples; i++){
scores[i] = rnd.nextDouble();
}
PriorityQueue<Integer> myHeap = new PriorityQueue(numberNeighbours, new Comparator<Integer>(){
@Override
public int compare(Integer o1, Integer o2) {
return -Double.compare(scores[o1], scores[o2]);
}
});
int top;
for(int i = 0; i < numberExamples; i++){
if(i < numberNeighbours){
myHeap.offer(i);
} else{
top = myHeap.peek();
if(scores[top] > scores[i]){
myHeap.poll();
myHeap.offer(i);
}
}
}
}
long endTime = System.currentTimeMillis();
return endTime - startTime;
}
public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
Random rnd = new Random(123);
long startTime = System.currentTimeMillis();
for(int iteration = 0; iteration < numberTries; iteration++){
final double[] scores = new double[numberExamples];
for(int i = 0; i < numberExamples; i++){
scores[i] = rnd.nextDouble();
}
int[] candidates = new int[numberNeighbours];
int top = 0;
for(int i = 0; i < numberExamples; i++){
if(i < numberNeighbours){
candidates[i] = i;
if(scores[candidates[top]] < scores[candidates[i]]) top = i;
} else{
if(scores[candidates[top]] > scores[i]){
candidates[top] = i;
top = 0;
for(int j = 1; j < numberNeighbours; j++){
if(scores[candidates[top]] < scores[candidates[j]]) top = j;
}
}
}
}
}
long endTime = System.currentTimeMillis();
return endTime - startTime;
}
这会产生以下结果:
tries: 10000 examples: 1000 neighbours: 10
time heap[ms]: 393
time simple[ms]: 388
创建最快的算法绝非易事,您需要考虑很多事情。例如,k 个元素是否需要排序返回,您的研究是否需要 stable(如果两个元素相等,您需要在第一个元素之前提取或不需要)?
在这场比赛中,理论上最好的解决方案是将第 k 个最小元素保存在有序数据结构中。因为插入经常发生在这个数据结构的中间,所以平衡排序树似乎是一个最佳解决方案。
但现实与此大相径庭
可能根据原始数组的大小和 k 的值混合使用不同的数据结构是最佳解决方案:
- 如果k很小用数组保存k个最小的值
- 如果k很大使用平衡树
- 如果 k 很大并且接近数组的维数,只需对数组进行排序(如果不能创建新的排序副本),然后提取前 k 个元素。
这种算法被命名为hibryd algorithm. A famous hybrid algorithm is Tim Sort,用于java 类对集合进行排序。
注:如果可以利用多线程的强大功能,可以使用不同的算法和数据结构。
关于微基准测试的补充说明。您的绩效指标可能会受到与算法效率无关的外部因素的强烈影响。正如您在这两个函数中所做的那样,创建对象可能需要内存,而这些内存是不可用的,需要 GC 完成额外的工作。这种因素对你的结果影响很大。至少尽量减少与要调查的代码部分不密切相关的代码。以不同的顺序重复测试,在调用测试之前等待以确保没有 GC 在运行。
第一个解决方案的时间复杂度为 O(numberExamples * log numberNeighbours)
,而第二个解决方案的时间复杂度为 O(numberExamples * numberNeighbours)
,因此对于足够大的输入,它必须更慢。第二种解决方案更快,因为您测试的是小 numberNeighbours
,而 PriorityQueue 比简单数组具有更大的开销。
你用PriorityQueue最优。
更快,但不是最优,只是对数组进行排序,然后最小的元素在 k 位置。
无论如何你可能想要实现 QuickSelect 算法,如果你会巧妙地选择枢轴元素你应该有更好的性能。你可能想看看这个https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention
首先,您的基准测试方法不正确。您正在测量输入数据的创建以及算法性能,并且您没有在测量之前预热 JVM。通过 JMH:
测试代码的结果
Benchmark Mode Cnt Score Error Units
CounterBenchmark.testHeap thrpt 2 18103,296 ops/s
CounterBenchmark.testSimple thrpt 2 59490,384 ops/s
修改后的基准 pastebin.
关于两个提供的解决方案之间的 3 倍差异。在大 O 符号方面,你的第一个算法可能看起来更好,但实际上大 O 符号只告诉你算法在缩放方面有多好,它永远不会告诉你它的执行速度有多快(见这个 question 还有)。在您的情况下,缩放不是问题,因为您的 numNeighbours
被限制为 20。换句话说,大 O 符号描述了完成它需要多少次算法滴答,但它不限制持续时间一个滴答声,它只是说当输入改变时,滴答声持续时间不会改变。就报价复杂度而言,您的第二个算法肯定会赢。
What is the fastest way to compute k smallest elements?
我想出了下一个解决方案,我相信它可以让 branch prediction 完成它的工作:
@Benchmark
public void testModified(Blackhole bh) {
final double[] scores = sampleData;
int[] candidates = new int[numberNeighbours];
for (int i = 0; i < numberNeighbours; i++) {
candidates[i] = i;
}
// sorting candidates so scores[candidates[0]] is the largest
for (int i = 0; i < numberNeighbours; i++) {
for (int j = i+1; j < numberNeighbours; j++) {
if (scores[candidates[i]] < scores[candidates[j]]) {
int temp = candidates[i];
candidates[i] = candidates[j];
candidates[j] = temp;
}
}
}
// processing other scores, while keeping candidates array sorted in the descending order
for (int i = numberNeighbours; i < numberExamples; i++) {
if (scores[i] > scores[candidates[0]]) {
continue;
}
// moving all larger candidates to the left, to keep the array sorted
int j; // here the branch prediction should kick-in
for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
candidates[j - 1] = candidates[j];
}
// inserting the new item
candidates[j - 1] = i;
}
bh.consume(candidates);
}
基准测试结果(比您当前的解决方案快 2 倍):
(10 neighbours) CounterBenchmark.testModified thrpt 2 136492,151 ops/s
(20 neighbours) CounterBenchmark.testModified thrpt 2 118395,598 ops/s
其他人提到了 quickselect,但正如人们所料,该算法的复杂性忽略了它在您的案例中的优势:
@Benchmark
public void testQuickSelect(Blackhole bh) {
final int[] candidates = new int[sampleData.length];
for (int i = 0; i < candidates.length; i++) {
candidates[i] = i;
}
final int[] resultIndices = new int[numberNeighbours];
int neighboursToAdd = numberNeighbours;
int left = 0;
int right = candidates.length - 1;
while (neighboursToAdd > 0) {
int partitionIndex = partition(candidates, left, right);
int smallerItemsPartitioned = partitionIndex - left;
if (smallerItemsPartitioned <= neighboursToAdd) {
while (left < partitionIndex) {
resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
}
} else {
right = partitionIndex - 1;
}
}
bh.consume(resultIndices);
}
private int partition(int[] locations, int left, int right) {
final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
final double pivotValue = sampleData[locations[pivotIndex]];
int storeIndex = left;
for (int i = left; i <= right; i++) {
if (sampleData[locations[i]] <= pivotValue) {
final int temp = locations[storeIndex];
locations[storeIndex] = locations[i];
locations[i] = temp;
storeIndex++;
}
}
return storeIndex;
}
在这种情况下,基准测试结果非常令人沮丧:
CounterBenchmark.testQuickSelect thrpt 2 11586,761 ops/s
我有一个元素数组(在示例中,这些只是整数),它们使用一些自定义比较器进行比较。在这个例子中,我通过定义 i SMALLER j
当且仅当 scores[i] <= scores[j]
.
我有两种方法:
- 使用当前k个候选人的堆
- 使用当前 k 个候选人的数组
我按照以下方式更新上面两个结构:
- 堆:方法
PriorityQueue.poll
和PriorityQueue.offer
, - 数组:存储候选数组中前k个候选中最差的索引
top
。如果新看到的示例比索引top
处的元素更好,则后者被前者替换,并且top
通过遍历数组的所有 k 个元素来更新。
但是,当我测试过哪种方法更快时,我发现这是第二种。问题是:
- 我对
PriorityQueue
的使用不是最理想的吗? - 计算 k 个最小元素的最快方法是什么?
我对这种情况很感兴趣,当示例数量可以很大,但邻居数量相对较少(在 10 到 20 之间)时。
代码如下:
public static void main(String[] args) {
long kopica, navadno, sortiranje;
int numTries = 10000;
int numExamples = 1000;
int numNeighbours = 10;
navadno = testSimple(numExamples, numNeighbours, numTries);
kopica = testHeap(numExamples, numNeighbours, numTries);
sortiranje = testSort(numExamples, numNeighbours, numTries, false);
System.out.println(String.format("tries: %d examples: %d neighbours: %d\n time heap[ms]: %d\n time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}
public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
Random rnd = new Random(123);
long startTime = System.currentTimeMillis();
for(int iteration = 0; iteration < numberTries; iteration++){
final double[] scores = new double[numberExamples];
for(int i = 0; i < numberExamples; i++){
scores[i] = rnd.nextDouble();
}
PriorityQueue<Integer> myHeap = new PriorityQueue(numberNeighbours, new Comparator<Integer>(){
@Override
public int compare(Integer o1, Integer o2) {
return -Double.compare(scores[o1], scores[o2]);
}
});
int top;
for(int i = 0; i < numberExamples; i++){
if(i < numberNeighbours){
myHeap.offer(i);
} else{
top = myHeap.peek();
if(scores[top] > scores[i]){
myHeap.poll();
myHeap.offer(i);
}
}
}
}
long endTime = System.currentTimeMillis();
return endTime - startTime;
}
public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
Random rnd = new Random(123);
long startTime = System.currentTimeMillis();
for(int iteration = 0; iteration < numberTries; iteration++){
final double[] scores = new double[numberExamples];
for(int i = 0; i < numberExamples; i++){
scores[i] = rnd.nextDouble();
}
int[] candidates = new int[numberNeighbours];
int top = 0;
for(int i = 0; i < numberExamples; i++){
if(i < numberNeighbours){
candidates[i] = i;
if(scores[candidates[top]] < scores[candidates[i]]) top = i;
} else{
if(scores[candidates[top]] > scores[i]){
candidates[top] = i;
top = 0;
for(int j = 1; j < numberNeighbours; j++){
if(scores[candidates[top]] < scores[candidates[j]]) top = j;
}
}
}
}
}
long endTime = System.currentTimeMillis();
return endTime - startTime;
}
这会产生以下结果:
tries: 10000 examples: 1000 neighbours: 10
time heap[ms]: 393
time simple[ms]: 388
创建最快的算法绝非易事,您需要考虑很多事情。例如,k 个元素是否需要排序返回,您的研究是否需要 stable(如果两个元素相等,您需要在第一个元素之前提取或不需要)?
在这场比赛中,理论上最好的解决方案是将第 k 个最小元素保存在有序数据结构中。因为插入经常发生在这个数据结构的中间,所以平衡排序树似乎是一个最佳解决方案。
但现实与此大相径庭
可能根据原始数组的大小和 k 的值混合使用不同的数据结构是最佳解决方案:
- 如果k很小用数组保存k个最小的值
- 如果k很大使用平衡树
- 如果 k 很大并且接近数组的维数,只需对数组进行排序(如果不能创建新的排序副本),然后提取前 k 个元素。
这种算法被命名为hibryd algorithm. A famous hybrid algorithm is Tim Sort,用于java 类对集合进行排序。
注:如果可以利用多线程的强大功能,可以使用不同的算法和数据结构。
关于微基准测试的补充说明。您的绩效指标可能会受到与算法效率无关的外部因素的强烈影响。正如您在这两个函数中所做的那样,创建对象可能需要内存,而这些内存是不可用的,需要 GC 完成额外的工作。这种因素对你的结果影响很大。至少尽量减少与要调查的代码部分不密切相关的代码。以不同的顺序重复测试,在调用测试之前等待以确保没有 GC 在运行。
第一个解决方案的时间复杂度为 O(numberExamples * log numberNeighbours)
,而第二个解决方案的时间复杂度为 O(numberExamples * numberNeighbours)
,因此对于足够大的输入,它必须更慢。第二种解决方案更快,因为您测试的是小 numberNeighbours
,而 PriorityQueue 比简单数组具有更大的开销。
你用PriorityQueue最优。
更快,但不是最优,只是对数组进行排序,然后最小的元素在 k 位置。
无论如何你可能想要实现 QuickSelect 算法,如果你会巧妙地选择枢轴元素你应该有更好的性能。你可能想看看这个https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention
首先,您的基准测试方法不正确。您正在测量输入数据的创建以及算法性能,并且您没有在测量之前预热 JVM。通过 JMH:
测试代码的结果Benchmark Mode Cnt Score Error Units
CounterBenchmark.testHeap thrpt 2 18103,296 ops/s
CounterBenchmark.testSimple thrpt 2 59490,384 ops/s
修改后的基准 pastebin.
关于两个提供的解决方案之间的 3 倍差异。在大 O 符号方面,你的第一个算法可能看起来更好,但实际上大 O 符号只告诉你算法在缩放方面有多好,它永远不会告诉你它的执行速度有多快(见这个 question 还有)。在您的情况下,缩放不是问题,因为您的 numNeighbours
被限制为 20。换句话说,大 O 符号描述了完成它需要多少次算法滴答,但它不限制持续时间一个滴答声,它只是说当输入改变时,滴答声持续时间不会改变。就报价复杂度而言,您的第二个算法肯定会赢。
What is the fastest way to compute k smallest elements?
我想出了下一个解决方案,我相信它可以让 branch prediction 完成它的工作:
@Benchmark
public void testModified(Blackhole bh) {
final double[] scores = sampleData;
int[] candidates = new int[numberNeighbours];
for (int i = 0; i < numberNeighbours; i++) {
candidates[i] = i;
}
// sorting candidates so scores[candidates[0]] is the largest
for (int i = 0; i < numberNeighbours; i++) {
for (int j = i+1; j < numberNeighbours; j++) {
if (scores[candidates[i]] < scores[candidates[j]]) {
int temp = candidates[i];
candidates[i] = candidates[j];
candidates[j] = temp;
}
}
}
// processing other scores, while keeping candidates array sorted in the descending order
for (int i = numberNeighbours; i < numberExamples; i++) {
if (scores[i] > scores[candidates[0]]) {
continue;
}
// moving all larger candidates to the left, to keep the array sorted
int j; // here the branch prediction should kick-in
for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
candidates[j - 1] = candidates[j];
}
// inserting the new item
candidates[j - 1] = i;
}
bh.consume(candidates);
}
基准测试结果(比您当前的解决方案快 2 倍):
(10 neighbours) CounterBenchmark.testModified thrpt 2 136492,151 ops/s
(20 neighbours) CounterBenchmark.testModified thrpt 2 118395,598 ops/s
其他人提到了 quickselect,但正如人们所料,该算法的复杂性忽略了它在您的案例中的优势:
@Benchmark
public void testQuickSelect(Blackhole bh) {
final int[] candidates = new int[sampleData.length];
for (int i = 0; i < candidates.length; i++) {
candidates[i] = i;
}
final int[] resultIndices = new int[numberNeighbours];
int neighboursToAdd = numberNeighbours;
int left = 0;
int right = candidates.length - 1;
while (neighboursToAdd > 0) {
int partitionIndex = partition(candidates, left, right);
int smallerItemsPartitioned = partitionIndex - left;
if (smallerItemsPartitioned <= neighboursToAdd) {
while (left < partitionIndex) {
resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
}
} else {
right = partitionIndex - 1;
}
}
bh.consume(resultIndices);
}
private int partition(int[] locations, int left, int right) {
final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
final double pivotValue = sampleData[locations[pivotIndex]];
int storeIndex = left;
for (int i = left; i <= right; i++) {
if (sampleData[locations[i]] <= pivotValue) {
final int temp = locations[storeIndex];
locations[storeIndex] = locations[i];
locations[i] = temp;
storeIndex++;
}
}
return storeIndex;
}
在这种情况下,基准测试结果非常令人沮丧:
CounterBenchmark.testQuickSelect thrpt 2 11586,761 ops/s