如何在不溢出的情况下对另一个数字进行算术模运算?
How to do arithmetic modulo another number, without overflow?
我正在尝试为 Rust 的 u32
和 u64
数据类型实施快速素数测试。作为其中的一部分,我需要计算 (n*n)%d
,其中 n
和 d
分别是 u32
(或 u64
)。
虽然结果很容易符合数据类型,但我不知道如何计算它。据我所知,没有处理器原语。
对于 u32
我们可以伪造它——向上转换为 u64
,这样乘积就不会溢出,然后取模,然后向下转换为 u32
,知道这不会溢出。但是,由于我没有 u128
数据类型(据我所知),此技巧不适用于 u64
.
所以对于 u64
,我能想到的最明显的方法是计算 x*y
以获得一对 (carry, product)
of u64
,所以我们捕获溢出量而不是仅仅丢失它(或恐慌,或其他)。
有办法吗?或者其他解决问题的标准方法?
使用简单的数学:
(n*n)%d = (n%d)*(n%d)%d
为了证明这是真的,设置 n = k*d+r
:
n*n%d = k**2*d**2+2*k*d*r+r**2 %d = r**2%d = (n%d)*(n%d)%d
。下面是计算两个相乘数的模数的 Rust 代码,摘自维基百科:
fn mul_mod(mut x: u64, mut y: u64, m: u64) -> u64 {
let mut d = 0_u64;
let mp2 = m >> 1;
x %= m;
y %= m;
for _ in 0..64 {
d = if d > mp2 {
(d << 1) - m
} else {
d << 1
};
if x & 0x8000_0000_0000_0000_u64 != 0 {
d += y;
}
if d > m {
d -= m;
}
x <<= 1;
}
d
}
Richard Rast Wikipedia 版本仅适用于 63 位整数。我扩展了 Boiethios 提供的代码以处理所有 64 位无符号整数。
fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
let msb = 0x8000_0000_0000_0000;
let mut d = 0;
let mp2 = m >> 1;
x %= m;
y %= m;
if m & msb == 0 {
for _ in 0..64 {
d = if d > mp2 {
(d << 1) - m
} else {
d << 1
};
if x & msb != 0 {
d += y;
}
if d >= m {
d -= m;
}
x <<= 1;
}
d
} else {
for _ in 0..64 {
d = if d > mp2 {
d.wrapping_shl(1).wrapping_sub(m)
} else {
// the case d == m && x == 0 is taken care of
// after the end of the loop
d << 1
};
if x & msb != 0 {
let (mut d1, overflow) = d.overflowing_add(y);
if overflow {
d1 = d1.wrapping_sub(m);
}
d = if d1 >= m { d1 - m } else { d1 };
}
x <<= 1;
}
if d >= m { d - m } else { d }
}
}
#[test]
fn test_mul_mod64() {
let half = 1 << 16;
let max = std::u64::MAX;
assert_eq!(mul_mod64(0, 0, 2), 0);
assert_eq!(mul_mod64(1, 0, 2), 0);
assert_eq!(mul_mod64(0, 1, 2), 0);
assert_eq!(mul_mod64(1, 1, 2), 1);
assert_eq!(mul_mod64(42, 1, 2), 0);
assert_eq!(mul_mod64(1, 42, 2), 0);
assert_eq!(mul_mod64(42, 42, 2), 0);
assert_eq!(mul_mod64(42, 42, 42), 0);
assert_eq!(mul_mod64(42, 42, 41), 1);
assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);
assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
assert_eq!(mul_mod64(half, half, half), 0);
assert_eq!(mul_mod64(half+1, half+1, half), 1);
assert_eq!(mul_mod64(max, max, max), 0);
assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
assert_eq!(mul_mod64(1239876, max, max), 0);
assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
assert_eq!(mul_mod64(max, 2948635, max), 0);
assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
assert_eq!(mul_mod64(max-1, max-1, max), 1);
assert_eq!(mul_mod64(2, max/2, max-1), 0);
}
这是另一种方法(现在有 u128
数据类型):
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
let (a, b, m) = (a as u128, b as u128, m as u128);
((a * b) % m) as u64
}
这种方法仅依赖于 LLVM 的 128 位整数算法。
我喜欢这个版本的一点是,它真的很容易让自己相信该解决方案对整个域都是正确的。由于 a
和 b
是 u64
产品保证适合 u128
,并且由于 m
是 u64
最后保证安全
我不知道与其他方法相比性能如何,但如果它显着变慢,我会感到非常惊讶。如果您真的关心性能,您会想要 运行 一些基准测试,并在任何情况下尝试一些替代方案。
我正在尝试为 Rust 的 u32
和 u64
数据类型实施快速素数测试。作为其中的一部分,我需要计算 (n*n)%d
,其中 n
和 d
分别是 u32
(或 u64
)。
虽然结果很容易符合数据类型,但我不知道如何计算它。据我所知,没有处理器原语。
对于 u32
我们可以伪造它——向上转换为 u64
,这样乘积就不会溢出,然后取模,然后向下转换为 u32
,知道这不会溢出。但是,由于我没有 u128
数据类型(据我所知),此技巧不适用于 u64
.
所以对于 u64
,我能想到的最明显的方法是计算 x*y
以获得一对 (carry, product)
of u64
,所以我们捕获溢出量而不是仅仅丢失它(或恐慌,或其他)。
有办法吗?或者其他解决问题的标准方法?
使用简单的数学:
(n*n)%d = (n%d)*(n%d)%d
为了证明这是真的,设置 n = k*d+r
:
n*n%d = k**2*d**2+2*k*d*r+r**2 %d = r**2%d = (n%d)*(n%d)%d
fn mul_mod(mut x: u64, mut y: u64, m: u64) -> u64 {
let mut d = 0_u64;
let mp2 = m >> 1;
x %= m;
y %= m;
for _ in 0..64 {
d = if d > mp2 {
(d << 1) - m
} else {
d << 1
};
if x & 0x8000_0000_0000_0000_u64 != 0 {
d += y;
}
if d > m {
d -= m;
}
x <<= 1;
}
d
}
Richard Rast
fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
let msb = 0x8000_0000_0000_0000;
let mut d = 0;
let mp2 = m >> 1;
x %= m;
y %= m;
if m & msb == 0 {
for _ in 0..64 {
d = if d > mp2 {
(d << 1) - m
} else {
d << 1
};
if x & msb != 0 {
d += y;
}
if d >= m {
d -= m;
}
x <<= 1;
}
d
} else {
for _ in 0..64 {
d = if d > mp2 {
d.wrapping_shl(1).wrapping_sub(m)
} else {
// the case d == m && x == 0 is taken care of
// after the end of the loop
d << 1
};
if x & msb != 0 {
let (mut d1, overflow) = d.overflowing_add(y);
if overflow {
d1 = d1.wrapping_sub(m);
}
d = if d1 >= m { d1 - m } else { d1 };
}
x <<= 1;
}
if d >= m { d - m } else { d }
}
}
#[test]
fn test_mul_mod64() {
let half = 1 << 16;
let max = std::u64::MAX;
assert_eq!(mul_mod64(0, 0, 2), 0);
assert_eq!(mul_mod64(1, 0, 2), 0);
assert_eq!(mul_mod64(0, 1, 2), 0);
assert_eq!(mul_mod64(1, 1, 2), 1);
assert_eq!(mul_mod64(42, 1, 2), 0);
assert_eq!(mul_mod64(1, 42, 2), 0);
assert_eq!(mul_mod64(42, 42, 2), 0);
assert_eq!(mul_mod64(42, 42, 42), 0);
assert_eq!(mul_mod64(42, 42, 41), 1);
assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);
assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
assert_eq!(mul_mod64(half, half, half), 0);
assert_eq!(mul_mod64(half+1, half+1, half), 1);
assert_eq!(mul_mod64(max, max, max), 0);
assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
assert_eq!(mul_mod64(1239876, max, max), 0);
assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
assert_eq!(mul_mod64(max, 2948635, max), 0);
assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
assert_eq!(mul_mod64(max-1, max-1, max), 1);
assert_eq!(mul_mod64(2, max/2, max-1), 0);
}
这是另一种方法(现在有 u128
数据类型):
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
let (a, b, m) = (a as u128, b as u128, m as u128);
((a * b) % m) as u64
}
这种方法仅依赖于 LLVM 的 128 位整数算法。
我喜欢这个版本的一点是,它真的很容易让自己相信该解决方案对整个域都是正确的。由于 a
和 b
是 u64
产品保证适合 u128
,并且由于 m
是 u64
最后保证安全
我不知道与其他方法相比性能如何,但如果它显着变慢,我会感到非常惊讶。如果您真的关心性能,您会想要 运行 一些基准测试,并在任何情况下尝试一些替代方案。