了解快速求幂函数

Understanding fast exponentiation function

我无法理解此功能为何有效?有人可以逐步解释它在做什么吗?我知道这个想法是 a^n 等于 (a^(n/2))^2 如果 n 是偶数或 a(a^((n-1)/2))^2 如果 n 是奇数,但是这个函数是怎么做到的呢?

double pow(double a, int n) {
    double ret = 1;
    while(n) {
        if(n%2 == 1) ret *= a;
        a *= a; n /= 2;
    }
    return ret;
}

这是我的 Python 递归代码,它在我看来更具可读性和易懂性(我知道在 Python 中创建递归函数不是个好主意,但我选择了 Python 因为其简单的语法来演示这个想法)。

def pow(n, e):
    if e == 0:
        return 1

    if e % 2 == 1:
        return n * pow(n, e - 1)

    # this step makes the algorithm to run in O(lg n) time
    tmp = pow(n, e / 2)

    return tmp * tmp

我再强调一次,tmp = pow(n, e / 2)是降低时间复杂度的那一行。

该算法不是将 e 乘以数字 n,而是重复使用了一些先前计算的结果。例如 2^8 将计算为 2^4 * 2^4。这里 2^4 将只计算一次,并且将以这种方式跳过一半的迭代。 2^4 等相同

我试图以某种方式更直观地解释它,但没有深入研究此优化背后的理论。如果您想更深入地了解它以及它在位级别上的工作原理,这里有一个很好的tutorial

本程序中使用的等式如下:

  1. invariant of the loop是:(在循环的每一步),a^n * ret是结果。事实上,一开始ret1,而在循环结束时n == 0,所以a^0 * ret就是结果,而由于a^0 == 1ret 是预期结果。
  2. 如果n是奇数,(即n%2 == 1),则存在b≥0使得n=b*2+1。在这种情况下,我们使用以下等式:a^(b*2+1)=(a^(b*2))*a。所以 ret 乘以 a
  3. 在接下来的语句中,使用了下面的等式:a^(b*2) = (a^2)^b,使得a自乘,n除以2,最终保持不变量.

注意在循环内部,n /= 2中使用了整数除法,所以两种情况下结果总是bn奇数,即n=b*2+1, 或 n 为偶数, 即 n=b*2).

最后,请注意,正如@chux 在评论中指出的那样,该函数无法正确管理 n.

的负值

我将从一些更明显的代码开始:

double pow(double a, int n) {
    int k = 0, m = 1, n2 = n;
    double pow_k = 1.0, pow_m = a;
    assert (n2 * m + k == n);

    while (n2 != 0) {
        if (n2 % 2 != 0) { k += m; pow_k *= pow_m; n2 -= 1; }
        assert (n2 * m + k == n); assert (n2 % 2 == 0);
        m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
        assert (n2 * m + k == n);
    }

    return pow_k;
}

在循环中的每个点,pow_k = a^k 和 pow_m = a^m。 n2 * m + k == n 始终为真。当 n2 == n, m == 1, k == 0 时初始为真。

在循环中的第一个 if 语句之前,n2 为偶数,因此断言保持为真且 n2 保持为偶数。或者 n2 是奇数。在这种情况下,n2 减 1,n2 * m 减 m; k 增加 m,n2 * m + k 不变。 n2 变得均匀。

然后 m 加倍,n2 恰好减半,因为 n2 是偶数,n2 * m + k 再次保持不变。

由于n2每次迭代都要除以2,最终n2变为0,所以循环结束。 n2 == 0 的断言表示 0 * m + k == n 或 k == n,因此 pow_k = a^k = a^n。因此返回的结果是 a^n。

现在我们省略了 k、m 和断言,这不会改变计算:

double pow(double a, int n) {
    int n2 = n;
    double pow_k = 1.0, pow_m = a;

    while (n2 != 0) {
        if (n2 % 2 != 0) { pow_k *= pow_m; n2 -= 1; }
        m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
    }

    return pow_k;
}

当 n2 为奇数时,我们可以删除 n2 -= 1,因为除以 2 后没有区别。由于没有使用 n,我们可以只使用 n 而不是 n2:

double pow(double a, int n) {
    double pow_k = 1.0, pow_m = a;

    while (n != 0) {
        if (n % 2 != 0) pow_k *= pow_m;
        pow_m = pow_m * pow_m; n /= 2;
    }

    return pow_k;
}

现在我们把pow_k改成ret,pow_m改成a,把n % 2 != 0改成n % 2 == 1,得到原码:

double pow(double a, int n) {
    double ret = 1.0;

    while (n != 0) {
        if (n % 2 == 1) ret *= a;
        a *= a; n /= 2;
    }

    return ret;
}