计算 BigInteger[] 的乘积

Calculating the product of BigInteger[]

背景:我正在尝试使用 Java 中的 BigInteger 类计算非常大的 n(n > 100,000)的阶乘,到目前为止,我的做法如下:

根据我在网上查到的资料,这种做法在渐近意义上比简单地把 1 到 n 的所有 k 依次相乘更快。但我注意到,我的实现中最慢的部分是把所有素数幂相乘的那一步。我的问题是:

代码:

/**
 * Multiplies all values in {@code numbers} with a balanced divide-and-conquer (product tree),
 * so that the two operands of each multiplication are products of similar numbers of factors.
 *
 * <p>The recursion works on index ranges of the input array instead of copying halves into
 * freshly allocated arrays at every level.
 *
 * @param numbers the factors; must be non-empty and contain no {@code null} elements
 * @return the product of all elements (the element itself when there is exactly one)
 * @throws ArithmeticException if {@code numbers} is empty
 */
public static BigInteger product(BigInteger[] numbers) {
    if (numbers.length == 0) {
        throw new ArithmeticException("There is nothing to multiply!");
    }
    return productRange(numbers, 0, numbers.length);
}

/** Multiplies {@code numbers[from, to)}; requires {@code to - from >= 1}. */
private static BigInteger productRange(BigInteger[] numbers, int from, int to) {
    int count = to - from;
    if (count == 1) {
        return numbers[from];
    }
    if (count == 2) {
        return numbers[from].multiply(numbers[from + 1]);
    }
    // Same split as before: the first half gets floor(count / 2) elements.
    // Unsigned shift keeps the midpoint correct even if from + to overflows int.
    int middle = (from + to) >>> 1;
    return productRange(numbers, from, middle).multiply(productRange(numbers, middle, to));
}

我提出另一个思路:由于 pow 算法很快,可以先求出每个素数及其在 n! 中的指数,像这样:

11! -> {2^8, 3^4, 5^2, 7^1, 11^1}(验证:2^8 · 3^4 · 5^2 · 7 · 11 = 39916800 = 11!)

然后计算出每个素数的幂,再用分治法把它们全部相乘。实现:

/**
 * Multiplies the elements {@code primesExp[min, max)} by recursively splitting the range in half,
 * so the operands of each multiplication are products of similar numbers of factors.
 *
 * @param primesExp the factors to multiply
 * @param min first index (inclusive)
 * @param max last index (exclusive)
 * @return the product of the range, or {@code BigInteger.ONE} if the range is empty
 */
private static BigInteger divideAndConquer(List<BigInteger> primesExp, int min, int max) {
    if (min >= max) {
        return BigInteger.ONE;
    }
    if (max - min == 1) {
        return primesExp.get(min);
    }
    // Unsigned shift avoids the int overflow that (max + min) / 2 can hit.
    int middle = (min + max) >>> 1;
    return divideAndConquer(primesExp, min, middle).multiply(divideAndConquer(primesExp, middle, max));
}

/**
 * Computes {@code n!} as the product of prime powers {@code p^e(p)} for all primes {@code p <= n}.
 *
 * <p>Primes are found with a sieve of Eratosthenes. Each exponent comes from Legendre's formula
 * {@code e(p) = floor(n/p) + floor(n/p^2) + ...}, which avoids trial-dividing every
 * {@code i <= n}. The prime powers are generated in ascending prime order and multiplied with
 * {@link #divideAndConquer}.
 *
 * @param n the argument; must be non-negative (a {@code boolean[n + 1]} sieve is allocated, so
 *     {@code n == Integer.MAX_VALUE} is not supported)
 * @return {@code n!}; {@code 0! == 1! == 1}
 * @throws IllegalArgumentException if {@code n < 0}
 */
public static BigInteger factorial(int n) {
    if (n < 0) {
        throw new IllegalArgumentException("n must be non-negative: " + n);
    }
    // composite[k] becomes true once k is known to have a smaller prime factor.
    boolean[] composite = new boolean[n + 1];
    List<BigInteger> primePows = new ArrayList<>();
    for (int p = 2; p <= n; p++) {
        if (composite[p]) {
            continue;
        }
        // Mark multiples from p*p on; long arithmetic prevents p*p / m += p from overflowing int.
        for (long multiple = (long) p * p; multiple <= n; multiple += p) {
            composite[(int) multiple] = true;
        }
        // Legendre: floor(floor(n / p^k) / p) == floor(n / p^(k+1)), so repeated division
        // walks the series without computing p^k (which could overflow).
        int exponent = 0;
        for (int quotient = n / p; quotient > 0; quotient /= p) {
            exponent += quotient;
        }
        primePows.add(BigInteger.valueOf(p).pow(exponent));
    }
    return divideAndConquer(primePows, 0, primePows.size());
}

提高性能的一种方法是执行以下操作:

  1. 对需要相乘的数字数组进行排序
  2. 创建两个新列表:a 和 b
  3. 输入列表中需要相乘的每个数字很可能出现不止一次。假设数字 v_i 出现 n_i 次,则将 v_i 添加到 a 中 ⌊n_i / 2⌋ 次(向下取整);如果 n_i 是奇数,再将 v_i 添加到 b 中一次。
  4. 要计算结果,请执行:
// a holds each value floor(n_i / 2) times and b holds one copy of every value whose count n_i
// is odd, so the full product equals product(a)^2 * product(b).
// product(BigInteger[]) throws on an empty array, but a or b may legitimately be empty
// (e.g. every count even), in which case that part contributes 1.
// NOTE(review): assumes a and b are BigInteger[] to match product(BigInteger[]); adapt if they are lists.
BigInteger A = a.length == 0 ? BigInteger.ONE : product(a);
BigInteger B = b.length == 0 ? BigInteger.ONE : product(b);
return A.multiply(A).multiply(B);

要了解它是如何工作的,假设您的输入数组是 [2, 2, 2, 2, 3, 3, 3],即有四个 2 和三个 3。数组 a 和 b 将分别为

a = [2, 2, 3]
b = [3]

然后递归地计算 a 和 b 的乘积。请注意,需要相乘的数字个数从 7 个减少到了 4 个,几乎减半。这里的诀窍是:对于出现多次的数,只需计算其中一半的乘积,然后再求平方。这与在 O(log n) 时间内计算一个数的幂(快速幂)的思路非常相似。

可能是最快的方法:

Sequence.java

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * A mutable sequence of {@link BigInteger} values with helpers to sum and multiply them.
 *
 * <p>The wrapped list is used directly and never copied: {@link #of(List, boolean)} and
 * {@link #subSequence(int, int, boolean)} share storage with their source list.
 */
public final class Sequence {

    private final List<BigInteger> elements;

    private Sequence(List<BigInteger> elements) {
        this.elements = elements;
    }

    /** Returns the backing list itself (not a copy); changes made through it affect this sequence. */
    public List<BigInteger> getElements() {
        return elements;
    }

    /** Returns the number of elements. */
    public int size() {
        return elements.size();
    }

    /** Returns an unsynchronized view of the elements in {@code [startInclusive, endExclusive)}. */
    public Sequence subSequence(int startInclusive, int endExclusive) {
        return subSequence(startInclusive, endExclusive, false);
    }

    /**
     * Returns a view of the elements in {@code [startInclusive, endExclusive)}, backed by this
     * sequence's list.
     *
     * @param sync if true, the view is wrapped with {@link Collections#synchronizedList(List)};
     *     that wrapper locks on itself, not on this sequence's list
     */
    public Sequence subSequence(int startInclusive, int endExclusive, boolean sync) {
        return Sequence.of(elements.subList(startInclusive, endExclusive), sync);
    }

    /** Appends {@code element} to the end of the sequence. */
    public void addLast(BigInteger element) {
        elements.add(element);
    }

    /**
     * Removes and returns the last element.
     *
     * @throws IndexOutOfBoundsException if the sequence is empty
     */
    public BigInteger removeLast() {
        return elements.remove(size() - 1);
    }

    /** Returns the sum of all elements ({@code BigInteger.ZERO} when empty), computed sequentially. */
    public BigInteger sum() {
        return sum(false);
    }

    /**
     * Returns the sum of all elements ({@code BigInteger.ZERO} when empty).
     *
     * @param parallel whether to reduce with a parallel stream
     */
    public BigInteger sum(boolean parallel) {
        return parallel
                ? elements.parallelStream().reduce(BigInteger.ZERO, BigInteger::add)
                : elements.stream().reduce(BigInteger.ZERO, BigInteger::add);
    }

    /** Returns the product of all elements ({@code BigInteger.ONE} when empty), computed sequentially. */
    public BigInteger product() {
        return product(false);
    }

    /**
     * Returns the product of all elements ({@code BigInteger.ONE} when empty). Elements must be
     * non-null.
     *
     * <p>The sequential path uses a balanced product tree rather than a left-to-right fold. A fold
     * repeatedly multiplies an ever-growing accumulator by one small factor. The tree multiplies
     * operands of similar size, which is typically much faster for long sequences such as 1..n.
     * The result is the same either way because integer multiplication is associative and
     * commutative.
     *
     * @param parallel whether to reduce with a parallel stream
     */
    public BigInteger product(boolean parallel) {
        if (parallel) {
            // Note: streaming a Collections.synchronizedList is not itself synchronized; callers
            // must not modify the list concurrently.
            return elements.parallelStream().reduce(BigInteger.ONE, BigInteger::multiply);
        }
        // Snapshot into an array for O(1) indexed access regardless of the List implementation.
        BigInteger[] factors = elements.toArray(new BigInteger[0]);
        return productTree(factors, 0, factors.length);
    }

    /** Multiplies {@code factors[from, to)} by recursively splitting the range in half. */
    private static BigInteger productTree(BigInteger[] factors, int from, int to) {
        int count = to - from;
        if (count == 0) {
            return BigInteger.ONE;
        }
        if (count == 1) {
            return factors[from];
        }
        if (count == 2) {
            return factors[from].multiply(factors[from + 1]);
        }
        int middle = (from + to) >>> 1;
        return productTree(factors, from, middle).multiply(productTree(factors, middle, to));
    }

    /** Returns an unsynchronized sequence of the integers in {@code [startInclusive, endExclusive)}. */
    public static Sequence range(int startInclusive, int endExclusive) {
        return range(startInclusive, endExclusive, false);
    }

    /**
     * Returns a sequence of the integers in {@code [startInclusive, endExclusive)}.
     *
     * @param sync if true, the backing list is wrapped with {@link Collections#synchronizedList(List)}
     * @throws IllegalArgumentException if {@code startInclusive > endExclusive}
     */
    public static Sequence range(int startInclusive, int endExclusive, boolean sync) {
        if (startInclusive > endExclusive) {
            throw new IllegalArgumentException();
        }
        // Presize the list; if the span overflows int the difference is negative and we fall
        // back to growing from an empty capacity.
        final int capacity = Math.max(0, endExclusive - startInclusive);
        final List<BigInteger> values = new ArrayList<>(capacity);
        for (int value = startInclusive; value < endExclusive; value++) {
            values.add(BigInteger.valueOf(value));
        }
        return new Sequence(sync ? Collections.synchronizedList(values) : values);
    }

    /** Wraps {@code elements} (without copying) in an unsynchronized sequence. */
    public static Sequence of(List<BigInteger> elements) {
        return of(elements, false);
    }

    /**
     * Wraps {@code elements} (without copying) in a sequence.
     *
     * @param sync if true, the list is wrapped with {@link Collections#synchronizedList(List)}
     */
    public static Sequence of(List<BigInteger> elements, boolean sync) {
        return new Sequence(sync ? Collections.synchronizedList(elements) : elements);
    }

    /** Returns a new, empty, unsynchronized sequence backed by an {@link ArrayList}. */
    public static Sequence empty() {
        return empty(false);
    }

    /**
     * Returns a new, empty sequence backed by an {@link ArrayList}.
     *
     * @param sync if true, the list is wrapped with {@link Collections#synchronizedList(List)}
     */
    public static Sequence empty(boolean sync) {
        return of(new ArrayList<>(), sync);
    }

}

FactorialCalculator.java

import java.math.BigInteger;
import java.util.LinkedList;
import java.util.List;

/** Computes factorials as products of {@link Sequence} ranges, optionally in parallel. */
public final class FactorialCalculator {

    /**
     * Number of elements multiplied together per parallel task. It is clamped to at least 2
     * because a chunk size of 1 (e.g. a single-CPU machine) would make every reduction round
     * return a sequence of the same size, so {@code fact(n, true)} would never terminate.
     */
    private static final int CHUNK_SIZE = Math.max(2, Runtime.getRuntime().availableProcessors());

    /** Computes {@code n!} sequentially; see {@link #fact(int, boolean)}. */
    public static BigInteger fact(int n) {
        return fact(n, false);
    }

    /**
     * Computes {@code n!}.
     *
     * <p>In parallel mode the range 1..n is split into chunks whose products are computed
     * concurrently. The resulting partial products are repeatedly re-chunked and multiplied until
     * at most {@code CHUNK_SIZE} remain, and those are multiplied together at the end.
     *
     * @param n the argument; must be non-negative
     * @param parallel whether to use parallel streams
     * @return {@code n!}; {@code 0! == 1! == 1}
     * @throws IllegalArgumentException if {@code n < 0}
     */
    public static BigInteger fact(int n, boolean parallel) {
        if (n < 0) {
            throw new IllegalArgumentException();
        }
        if (n <= 1) {
            return BigInteger.ONE;
        }
        Sequence sequence = Sequence.range(1, n + 1);
        if (!parallel) {
            return sequence.product();
        }
        sequence = parallelCalculate(splitSequence(sequence, CHUNK_SIZE * 2));
        // Each round shrinks the size to ceil(size / CHUNK_SIZE), which is strictly smaller
        // whenever size > CHUNK_SIZE >= 2, so the loop terminates.
        while (sequence.size() > CHUNK_SIZE) {
            sequence = parallelCalculate(splitSequence(sequence, CHUNK_SIZE));
        }
        return sequence.product(true);
    }

    /**
     * Splits {@code sequence} into consecutive views of at most {@code chunkSize} elements.
     * The views share storage with {@code sequence}, which must not be modified while they are in use.
     */
    private static List<Sequence> splitSequence(Sequence sequence, int chunkSize) {
        final int size = sequence.size();
        final List<Sequence> subSequences = new LinkedList<>();
        int index = 0;
        while (index < size) {
            final int targetIndex = Math.min(index + chunkSize, size);
            subSequences.add(sequence.subSequence(index, targetIndex, true));
            index = targetIndex;
        }
        return subSequences;
    }

    /**
     * Multiplies each sequence concurrently and collects the partial products. The order of the
     * partial products is unspecified, which is harmless because they are only multiplied further.
     */
    private static Sequence parallelCalculate(List<Sequence> sequences) {
        // Synchronized because the parallel forEach appends from multiple threads.
        final Sequence result = Sequence.empty(true);
        sequences.parallelStream().map(s -> s.product(true)).forEach(result::addLast);
        return result;
    }

}

测试:

/**
 * Rough benchmark: warms up the JIT on small inputs, then times one parallel computation of
 * 1,000,000! and prints the elapsed milliseconds.
 */
public static void main(String[] args) {
    // Warm up both the sequential and the parallel paths; the parallel path is the one timed below.
    for (int i = 0; i < 100; i++) {
        FactorialCalculator.fact(10000);
        FactorialCalculator.fact(10000, true);
    }
    int n = 1000000;
    // nanoTime is monotonic and fine-grained; currentTimeMillis is wall-clock time and can be
    // coarse (on the order of 10+ ms on some Windows systems) or jump.
    long start = System.nanoTime();
    int bitLength = FactorialCalculator.fact(n, true).bitLength();
    long elapsedMs = (System.nanoTime() - start) / 1_000_000L;
    System.out.printf("Execution time = %d ms%n", elapsedMs);
    // Consume the result so the measured computation cannot be treated as dead code.
    System.out.printf("Result bit length = %d%n", bitLength);
}

结果:

Execution time = 3066 ms
  • OS : Win 10 专业版 64 位
  • CPU:英特尔酷睿 i7-4700HQ @ 2.40GHz 2.40GHz