快速整数平方根近似

fast integer square root approximation

我目前正在寻找一个非常快速的整数平方根近似值,其中 floor(sqrt(x)) <= veryFastIntegerSquareRoot(x) <= x

平方根例程用于计算素数,如果仅检查小于或等于 sqrt(x) 的值作为 x 的约数,计算速度会大大加快。

我目前拥有的是 this function from Wikipedia,稍作调整以使用 64 位整数。

因为我没有其他函数可以比较(或者更准确地说,这个函数对我的目的来说太精确了,而且它可能需要更多的时间,而不是比实际结果更高。)

Loopfree/jumpfree(好吧:几乎 ;-)牛顿-拉夫森:

/* static will allow inlining */
static unsigned usqrt4(unsigned val) {
    unsigned a, b;

    if (val < 2) return val; /* avoid div/0 */

    a = 1255;       /* starting point is relatively unimportant */

    b = val / a; a = (a+b) /2;
    b = val / a; a = (a+b) /2;
    b = val / a; a = (a+b) /2;
    b = val / a; a = (a+b) /2;

    return a;
}

对于 64 位整数,您将需要更多步骤(我的猜测:6)

这个版本可以更快,因为 DIV 很慢而且数量少 (Val<20k) 此版本将错误减少到 5% 以下。 在 ARM M0 上测试(没有 DIV 硬件加速)

static uint32_t usqrt4_6(uint32_t val) {
    uint32_t a, b;

    if (val < 2) return val; /* avoid div/0 */
    a = 1255;       /* starting point is relatively unimportant */
    b = val / a; a = (a + b)>>1;
    b = val / a; a = (a + b)>>1;
    b = val / a; a = (a + b)>>1;
    b = val / a; a = (a + b)>>1;
    if (val < 20000) {  
        b = val / a; a = (a + b)>>1;    // < 17% error Max
        b = val / a; a = (a + b)>>1;    // < 5%  error Max
    }
    return a;
}

在现代 PC 硬件上,使用浮点算法计算 n 的平方根可能比任何类型的快速整数数学运算都更快,但对于所有性能问题,需要仔细的基准测试。

但是请注意,可能根本不需要计算平方根:您可以改为对循环索引求平方并在平方超过 n 的值时停止。无论如何,占主导地位的操作是循环体中的除法:

#define PBITS32  ((1<<2) | (1<<3) | (1<<5) | (1<<7) | (1<<11) | (1<<13) | \
                  (1UL<<17) | (1UL<<19) | (1UL<<23) | (1UL<<29) | (1UL<<31))

int isprime(unsigned int n) {
    if (n < 32)
        return (PBITS32 >> n) & 1;
    if ((n & 1) == 0)
        return 0;
    for (unsigned int p = 3; p * p <= n; p += 2) {
        if (n % p == 0)
            return 0;
    }
    return 1;
}

计算 floor(sqrt(x)) 准确

这是我的解决方案,它基于 bit-guessing approach proposed on Wikipedia。不幸的是,维基百科上提供的 pseudo-code 有一些错误,所以我不得不做一些调整:

unsigned char bit_width(unsigned long long x) {
    return x == 0 ? 1 : 64 - __builtin_clzll(x);
}

// implementation for all unsigned integer types
unsigned sqrt(const unsigned n) {
    unsigned char shift = bit_width(n);
    shift += shift & 1; // round up to next multiple of 2

    unsigned result = 0;

    do {
        shift -= 2;
        result <<= 1; // leftshift the result to make the next guess
        result |= 1;  // guess that the next bit is 1
        result ^= result * result > (n >> shift); // revert if guess too high
    } while (shift != 0);

    return result;
}

bit_width 可以在常数时间内计算,循环最多迭代 ceil(bit_width / 2) 次。因此,即使对于 64 位整数,这也最多是基本算术和按位运算的 32 次迭代。

与迄今为止提出的所有其他答案不同,这实际上为您提供了最佳近似值,即 floor(sqrt(x))。对于任何 x2,这将 return x 恰好。

使用 log2(x)

进行猜测

如果这对您来说仍然太慢,您可以仅根据二进制对数进行猜测。基本思想是我们可以使用 2x/2 计算任意数字 2xsqrtx/2 可能有余数,所以我们不能总是准确地计算它,但我们可以计算一个上限和下限。

例如:

  1. 我们得到 25
  2. 计算floor(log_2(25)) = 4
  3. 计算ceil(log_2(25)) = 5
  4. 下限:pow(2, floor(4 / 2)) = 4
  5. 上限:pow(2, ceil(5 / 2)) = 8

事实上,实际 sqrt(25) = 5。我们找到 sqrt(16) >= 4sqrt(32) <= 8。这意味着:

4 <= sqrt(16) <= sqrt(25) <= sqrt(32) <= 8
            4 <= sqrt(25) <= 8

这就是我们如何实现这些猜测,我们称之为 sqrt_losqrt_hi

// this function computes a lower bound
unsigned sqrt_lo(const unsigned n) noexcept
{
    unsigned log2floor = bit_width(n) - 1;
    return (unsigned) (n != 0) << (log2floor >> 1);
}

// this function computes an upper bound
unsigned sqrt_hi(const unsigned n)
{
    bool isnt_pow2 = ((n - 1) & n) != 0; // check if n is a power of 2
    unsigned log2ceil = bit_width(n) - 1 + isnt_pow2;
    log2ceil += log2ceil & 1; // round up to multiple of 2
    return (unsigned) (n != 0) << (log2ceil >> 1);
}

对于这两个函数,以下陈述总是正确的:

sqrt_lo(x) <= floor(sqrt(x)) <= sqrt(x) <= sqrt_hi(x) <= x

请注意,如果我们假设输入永远不会为零,那么 (unsigned) (n != 0) 可以简化为 1 并且语句仍然正确。

这些函数可以在 O(1) 中使用硬件-__builtin_clzll 指令进行计算。他们只给出数字 22x 的精确结果,因此 2566416

在具有高吞吐量双精度浮点支持的现代处理器上,计算参数 x ≤ 253[=57 的整数平方根 ⌊√x⌋ 的最快方法=] 就是将其计算为 (uint32_t)sqrt((double)x)。对于没有 FPU 或慢速双精度支持的处理器的 32 位整数平方根 suitable,请参阅我的 this answer

64 位无符号整数的平方根可以通过首先计算平方根的倒数 1/√x 或 rsqrt(x),使用低精度 table 查找并在 定点 算术中进行多次 Newton-Raphson 迭代。然后将全精度倒数平方根乘以原始参数以得出平方根。 a 的平方根倒数的一般 Newton-Raphson 迭代是 rn+1 = rn + rn * (1 - a * rn2) / 2。这可以用代数形式转化为各种方便的安排。

下面的示例性 C99 代码演示了上述算法的详细信息。使用 规范化 参数的七个最高有效位作为索引,从 96 项查找 table 中检索平方根倒数的八位近似值。规范化需要前导零的计数,这是许多处理器架构上的内置硬件指令,但也可以通过单精度浮点计算或整数计算合理有效地模拟。

为了潜在地允许使用小操作数乘法,初始倒数平方根近似 r0 使用以下 Newton-Raphson 迭代变体进行细化:r1 = (3 * r0 - a * r03) / 2. 第二次迭代 r2 = (r1 * (3 - r1 ( (r1 * a))) / 2 用于进一步细化。第三次迭代与后乘 a 相结合,得到最终的平方根近似值: s2 = a * r2, s3 = s2 + (r2 * (a - s2 * s2)) / 2.

作为最后一步,必须对最终的归一化平方根近似值进行反归一化。向右移动的位数是归一化期间向左移动的位数的一半。结果是低估的,最多可以比真实结果小 2。正确的结果⌊√a⌋,可以通过检验余数的大小来判断。

要在许多 32 位平台上实现良好的性能,前两次 Newton-Raphson 迭代应该在 32 位算法中执行,因为在该状态下只需要有限的精度。可以通过使用具有 96 个 32 位条目的更大 table 来逐步加快计算速度,其中每个条目的最低有效十位存储 3 * r0 和最有效位存储 r03 rounded 到 22 位,这引入了可忽略的错误。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#if defined(_MSC_VER) && defined(_WIN64)
#include <intrin.h>
#endif // defined(_MSC_VER) && defined(_WIN64)

#define CLZ_BUILTIN  (1) // use compiler's built-in count-leading-zeros
#define CLZ_FPU      (2) // emulate count-leading-zeros via FPU
#define CLZ_CPU      (3) // emulate count-leading-zeros via CPU

#define LARGE_TABLE  (1)
#define CLZ_IMPL     (CLZ_CPU)
#define GEN_TAB      (1)

uint32_t umul32_hi (uint32_t a, uint32_t b);
uint64_t umul32_wide (uint32_t a, uint32_t b);
int clz64 (uint64_t a);

#if LARGE_TABLE
uint32_t rsqrt_tab [96] = 
{
    0xfa0bfafa, 0xee6b2aee, 0xe5f02ae5, 0xdaf26ed9, 0xd2f002d0, 0xc890c2c4,
    0xc1037abb, 0xb9a75ab2, 0xb4da42ac, 0xadcea2a3, 0xa6f27a9a, 0xa279c294,
    0x9beb4a8b, 0x97a5ca85, 0x9163427c, 0x8d4fca76, 0x89500270, 0x8563ba6a,
    0x818ac264, 0x7dc4ea5e, 0x7a120258, 0x7671da52, 0x72e4424c, 0x6f690a46,
    0x6db24243, 0x6a52423d, 0x67042637, 0x6563c234, 0x62302a2e, 0x609cea2b,
    0x5d836a25, 0x5bfd1a22, 0x58fd421c, 0x5783ae19, 0x560e4a16, 0x53300210,
    0x51c7120d, 0x50623a0a, 0x4da4c204, 0x4c4c1601, 0x4af769fe, 0x49a6b9fb,
    0x485a01f8, 0x471139f5, 0x45cc59f2, 0x448b5def, 0x4214fde9, 0x40df89e6,
    0x3fade1e3, 0x3e8001e0, 0x3d55e1dd, 0x3c2f79da, 0x3c2f79da, 0x3b0cc5d7,
    0x39edc1d4, 0x38d265d1, 0x37baa9ce, 0x36a689cb, 0x359601c8, 0x348909c5,
    0x348909c5, 0x337f99c2, 0x3279adbf, 0x317741bc, 0x30784db9, 0x30784db9,
    0x2f7cc9b6, 0x2e84b1b3, 0x2d9001b0, 0x2d9001b0, 0x2c9eb1ad, 0x2bb0b9aa,
    0x2bb0b9aa, 0x2ac615a7, 0x29dec1a4, 0x29dec1a4, 0x28fab5a1, 0x2819e99e,
    0x2819e99e, 0x273c599b, 0x273c599b, 0x26620198, 0x258ad995, 0x258ad995,
    0x24b6d992, 0x24b6d992, 0x23e5fd8f, 0x2318418c, 0x2318418c, 0x224d9d89,
    0x224d9d89, 0x21860986, 0x21860986, 0x20c18183, 0x20c18183, 0x20000180,
};
#else // LARGE_TABLE
uint8_t rsqrt_tab [96] = 
{
    0xfe, 0xfa, 0xf7, 0xf3, 0xf0, 0xec, 0xe9, 0xe6, 0xe4, 0xe1, 0xde, 0xdc,
    0xd9, 0xd7, 0xd4, 0xd2, 0xd0, 0xce, 0xcc, 0xca, 0xc8, 0xc6, 0xc4, 0xc2,
    0xc1, 0xbf, 0xbd, 0xbc, 0xba, 0xb9, 0xb7, 0xb6, 0xb4, 0xb3, 0xb2, 0xb0,
    0xaf, 0xae, 0xac, 0xab, 0xaa, 0xa9, 0xa8, 0xa7, 0xa6, 0xa5, 0xa3, 0xa2,
    0xa1, 0xa0, 0x9f, 0x9e, 0x9e, 0x9d, 0x9c, 0x9b, 0x9a, 0x99, 0x98, 0x97,
    0x97, 0x96, 0x95, 0x94, 0x93, 0x93, 0x92, 0x91, 0x90, 0x90, 0x8f, 0x8e,
    0x8e, 0x8d, 0x8c, 0x8c, 0x8b, 0x8a, 0x8a, 0x89, 0x89, 0x88, 0x87, 0x87,
    0x86, 0x86, 0x85, 0x84, 0x84, 0x83, 0x83, 0x82, 0x82, 0x81, 0x81, 0x80,
};
#endif //LARGE_TABLE 

uint32_t my_isqrt64 (uint64_t a)
{
    uint64_t rem, arg = a;
    uint32_t b, r, s, t, scal;

    /* Handle zero as special case */
    if (a == 0ULL) return 0u;
    /* Normalize argument */
    scal = clz64 (a) & ~1;
    a = a << scal;
    b = a >> 32;
    /* Generate initial approximation to 1/sqrt(a) = rsqrt(a) */
    t = rsqrt_tab [(b >> 25) - 32];
    /* Perform first NR iteration for rsqrt */
#if LARGE_TABLE
    r = (t << 22) - umul32_hi (b, t);
#else // LARGE_TABLE
    r = ((3 * t) << 22) - umul32_hi (b, (t * t * t) << 8);
#endif // LARGE_TABLE
    /* Perform second NR iteration for rsqrt */
    s = umul32_hi (r, b);
    s = 0x30000000 - umul32_hi (r, s);
    r = umul32_hi (r, s);
    /* Compute sqrt(a) as a * rsqrt(a); make sure it is an underestimate! */
    r = r * 16;
    s = umul32_hi (r, b);
    s = 2 * s - 10;
    /* Perform third NR iteration combined with back multiply */
    rem = a - umul32_wide (s, s);
    r = umul32_hi ((uint32_t)(rem >> 32), r);
    s = s + r;
    /* Denormalize result */
    s = s >> (scal >> 1);
    /* Make sure we get the floor correct; result underestimates by up to 2 */
    rem = arg - umul32_wide (s, s);
    if (rem >= ((uint64_t)s * 4 + 4)) s+=2;
    else if (rem >= ((uint64_t)s * 2 + 1)) s++;
    return s;
}

uint32_t umul32_hi (uint32_t a, uint32_t b)
{
    return (uint32_t)(((uint64_t)a * b) >> 32);
}

uint64_t umul32_wide (uint32_t a, uint32_t b)
{
    return (uint64_t)a * b;
}

uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

int clz32 (uint32_t a)
{
#if (CLZ_IMPL == CLZ_FPU)
    // Henry S. Warren, Jr, " Hacker's Delight 2nd ed", p. 105
    int n = 158 - (float_as_uint32 ((float)(int32_t)(a & ~(a >> 1))+.5f) >> 23);
    return (n < 0) ? 0 : n;
#elif (CLZ_IMPL == CLZ_CPU)
    static const uint8_t clz_tab[32] = {
        31, 22, 30, 21, 18, 10, 29,  2, 20, 17, 15, 13, 9,  6, 28, 1,
        23, 19, 11,  3, 16, 14,  7, 24, 12,  4,  8, 25, 5, 26, 27, 0
    };
    a |= a >> 16;
    a |= a >> 8;
    a |= a >> 4;
    a |= a >> 2;
    a |= a >> 1;
    return clz_tab [0x07c4acddu * a >> 27] + (!a);
#else // CLZ_IMPL == CLZ_BUILTIN
#if defined(_MSC_VER) && defined(_WIN64)
    return (int)__lzcnt (a);
#else // defined(_MSC_VER) && defined(_WIN64)
    return (int)__builtin_clz (a);
#endif // defined(_MSC_VER) && defined(_WIN64)
#endif // CLZ_IMPL
}

int clz64 (uint64_t a)
{
#if (CLZ_IMPL == CLZ_BUILTIN)
#if defined(_MSC_VER) && defined(_WIN64)
    return (int)__lzcnt64 (a);
#else // defined(_MSC_VER) && defined(_WIN64)
    return (int)__builtin_clzll (a);
#endif // defined(_MSC_VER) && defined(_WIN64)
#else // CLZ_IMPL
    uint32_t ah = (uint32_t)(a >> 32);
    uint32_t al = (uint32_t)(a);
    int r;
    if (ah) al = ah;
    r = clz32 (al);
    if (!ah) r += 32;
    return r;
#endif // CLZ_IMPL
}

/* Henry S. Warren, Jr., "Hacker's Delight, 2nd e.d", p. 286 */
uint32_t ref_isqrt64 (uint64_t x)
{
    uint64_t m, y, b;
    m = 0x4000000000000000ULL;
    y = 0ULL;
    while (m != 0) {
        b = y | m;
        y = y >> 1;
        if (x >= b) {
            x = x - b;
            y = y | m;
        }
        m = m >> 2;
    }
    return (uint32_t)y;
}

/*
  https://groups.google.com/forum/#!original/comp.lang.c/qFv18ql_WlU/IK8KGZZFJx4J
  From: geo <gmars...@gmail.com>
  Newsgroups: sci.math,comp.lang.c,comp.lang.fortran
  Subject: 64-bit KISS RNGs
  Date: Sat, 28 Feb 2009 04:30:48 -0800 (PST)

  This 64-bit KISS RNG has three components, each nearly
  good enough to serve alone.    The components are:
  Multiply-With-Carry (MWC), period (2^121+2^63-1)
  Xorshift (XSH), period 2^64-1
  Congruential (CNG), period 2^64
*/
static uint64_t kiss64_x = 1234567890987654321ULL;
static uint64_t kiss64_c = 123456123456123456ULL;
static uint64_t kiss64_y = 362436362436362436ULL;
static uint64_t kiss64_z = 1066149217761810ULL;
static uint64_t kiss64_t;
#define MWC64  (kiss64_t = (kiss64_x << 58) + kiss64_c, \
                kiss64_c = (kiss64_x >> 6), kiss64_x += kiss64_t, \
                kiss64_c += (kiss64_x < kiss64_t), kiss64_x)
#define XSH64  (kiss64_y ^= (kiss64_y << 13), kiss64_y ^= (kiss64_y >> 17), \
                kiss64_y ^= (kiss64_y << 43))
#define CNG64  (kiss64_z = 6906969069ULL * kiss64_z + 1234567ULL)
#define KISS64 (MWC64 + XSH64 + CNG64)

#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

int main (void)
{
#if LARGE_TABLE
    printf ("64-bit integer square root implementation w/ large table\n");
#else // LARGE_TAB
    printf ("64-bit integer square root implementation w/ small table\n");
#endif

#if GEN_TAB
    printf ("Generating table ...\n");
    for (int i = 0; i < 96; i++ ) {
        double x1 = 1.0 + i * 1.0 / 32;
        double x2 = 1.0 + (i + 1) * 1.0 / 32;
        double y = (1.0 / sqrt (x1) + 1.0 / sqrt (x2)) * 0.5;
        uint32_t val = (uint32_t)(y * 256 + 0.5);
#if LARGE_TABLE
        uint32_t cube = val * val * val;
        rsqrt_tab[i] = (((cube + 1) / 4) << 10) + (3 * val);
        printf ("0x%08x, ", rsqrt_tab[i]);
        if (i % 6 == 5) printf ("\n");
#else // LARGE_TABLE
        rsqrt_tab[i] = (uint8_t)val;
        printf ("0x%02x, ", rsqrt_tab[i]);
        if (i % 12 == 11) printf ("\n");
#endif // LARGE_TABLE
    }
#endif // GEN_TAB

    printf ("Running benchmark ...\n");

    double start, stop;
    uint32_t sum[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    for (int k = 0; k < 2; k++) {
        uint32_t i = 0;
        start = second();
        do {
            sum [0] += my_isqrt64 (0x31415926ULL * i + 0);
            sum [1] += my_isqrt64 (0x31415926ULL * i + 1);
            sum [2] += my_isqrt64 (0x31415926ULL * i + 2);
            sum [3] += my_isqrt64 (0x31415926ULL * i + 3);
            sum [4] += my_isqrt64 (0x31415926ULL * i + 4);
            sum [5] += my_isqrt64 (0x31415926ULL * i + 5);
            sum [6] += my_isqrt64 (0x31415926ULL * i + 6);
            sum [7] += my_isqrt64 (0x31415926ULL * i + 7);
            i += 8;
        } while (i);
        stop = second();
    }
    printf ("%08x\relapsed=%.5f sec\n", 
            sum[0]+sum[1]+sum[2]+sum[3]+sum[4]+sum[5]+sum[6]+sum[7],
            stop - start);

    printf ("Running functional test ...\n");
    uint64_t a, count = 0;
    uint32_t res, ref;
    do {
        switch (count >> 33) {
        case 0:
            a = count;
            break;
        case 1:
            a = (count & ((1ULL << 33) - 1)) * (count & ((1ULL << 33) - 1) - 1);
            break;
        case 2:
            a = (count & ((1ULL << 33) - 1)) * (count & ((1ULL << 33) - 1));
            break;
        case 3:
            a = (count & ((1ULL << 33) - 1)) * (count & ((1ULL << 33) - 1)) + 1;
            break;
        default:
            a = KISS64;
            break;
        }
        res = my_isqrt64 (a);
        ref = ref_isqrt64 (a);
        if (res != ref) {
            printf ("\nerror: arg=%016llx  res=%08x  ref=%08x  count=%llx\n", a, res, ref, count);
            return EXIT_FAILURE;
        }
        count++;
        if (!(count & 0xffffff)) printf ("\r%016llx", count);
    } while (count);
    printf ("PASSED\n");
    return EXIT_SUCCESS;
}

出于其他目的需要平方根算法,并在搜索时发现了这个线程。我最终得出的结论是 sqrt 与大值几乎是线性的。

如果所需的精度是例如 sqrt(x) > estimate > sqrt(x)-1,则可以使用这样的值: 0、16、146、581、1612、3623、7100 ... 100 个值 ... 549043200、569728768 对于标准 sqrt 函数并在它们之间进行线性化。

注意:以上数值为估计值,可能有误差。目的只是为了展示如果需要几个大小相同的 sqrt 值,可以使用多大的跨度进行线性化。