Z3:如何提高性能?

Z3: how to improve performance?

在下面的 Python 代码中:

from itertools import product
from z3 import *

def get_sp(v0, v1):
    res = sum([v0[i] * v1[i] for i in range(len(v0))])
    return res

def get_is_mod_partition(numbers):
    n = len(numbers)
    mod_power = 2**n
    for signs in product((1, -1), repeat = len(numbers)):
        if get_sp(numbers, signs) % mod_power == 0:
            return 1
    return 0

def check_sat(numbers):
    n = len(numbers)
    s = Solver()
    signs = [Int("s" + str(i)) for i in range(n)]
    for i in range(n):
        s.add(Or(signs[i] == -1, signs[i] == 1))
    s.add(get_sp(numbers, signs) % 2**n == 0)
    print(s.check())

l = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]
check_sat(l)
get_is_mod_partition(l)

check_sat 需要 22 秒,get_is_mod_partition - 24 毫秒。我没想到 "high-performance theorem prover" 会有这样的结果。有没有办法大幅提高性能?

按照 Patrick 的建议,您可以按如下方式编写代码:

from z3 import *

def get_sp(v0, v1):
    res = sum([If(v1[i], v0[i], -v0[i]) for i in range(len(v0))])
    return res

def check_sat(numbers):
    n = len(numbers)
    s = Solver()
    signs = [Bool("s" + str(i)) for i in range(n)]
    s.add(get_sp(numbers, signs) % 2**n == 0)
    print(s.check())
    m = s.model()
    mod_power = 2 ** n
    print ("("),
    for (n, sgn) in zip (numbers, signs):
        if m[sgn]:
           print ("+ %d" % n),
        else:
           print ("- %d" % n),
    print (") %% %d == 0" % mod_power)

l = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]
check_sat(l)

在我的机器上运行大约 0.14 秒,并打印:

sat
( - 1509 - 1245 - 425 + 2684 + 3630 + 435 - 875 + 2286 - 1886 - 1205 - 518 - 1372 ) % 4096 == 0

然而,正如帕特里克评论的那样,不清楚为什么这个版本比原来的版本快得多。我想做一些基准测试,并使用 Haskell 这样做,因为我更熟悉该语言及其 Z3 绑定:

import Data.SBV
import Criterion.Main

ls :: [SInteger]
ls = [1509, 1245, 425, 2684, 3630, 435, 875, 2286, 1886, 1205, 518, 1372]

original = do bs <- mapM (const free_) ls
              let inside b = constrain $ b .== 1 ||| b .== -1
              mapM_ inside bs
              return $ sum [b*l | (b, l) <- zip bs ls] `sMod` (2^length ls) .== 0

boolOnly = do bs <- mapM (const free_) ls
              return $ sum [ite b l (-l) | (b, l) <- zip bs ls] `sMod` (2^length ls) .== 0

main = defaultMain [ bgroup "problem" [ bench "orig" $ nfIO (sat original)
                                      , bench "bool" $ nfIO (sat boolOnly)
                                      ]
                   ]

事实上,仅布尔版本的速度快了大约 8 倍:

benchmarking problem/orig
time                 810.1 ms   (763.4 ms .. 854.7 ms)
                     0.999 R²   (NaN R² .. 1.000 R²)
mean                 808.4 ms   (802.2 ms .. 813.6 ms)
std dev              8.189 ms   (0.0 s .. 8.949 ms)
variance introduced by outliers: 19% (moderately inflated)

benchmarking problem/bool
time                 108.2 ms   (104.4 ms .. 113.5 ms)
                     0.997 R²   (0.992 R² .. 1.000 R²)
mean                 109.3 ms   (107.3 ms .. 111.5 ms)
std dev              3.408 ms   (2.261 ms .. 4.843 ms)

两个观察结果:

  • Haskell 绑定比 Python 绑定快得多!大约一个数量级
  • 与整数相比,仅限布尔的版本大约快另一个数量级

对于前者,找出为什么 Python 绑定表现如此糟糕可能会很有趣;或者简单地切换到 Haskell :-)

进一步分析,还有一个可爱的把戏

看来问题出在对 mod 的调用上。在Haskell翻译中,系统内部给所有的中间表达式命名;这似乎使 z3 运行得更快。然而,Python 绑定更全面地翻译表达式,并且检查生成的代码(您可以通过查看 s.sexpr() 看到它)表明它没有命名内部表达式。当涉及到 mod 时,我猜测求解器的启发式算法无法识别问题的基本线性并最终花费大量时间。

要缩短时间,您可以执行以下简单技巧。原文说:

s.add(get_sp(numbers, signs) % 2**n == 0)

相反,请确保总和有一个明确的名称。即,将上面一行替换为:

ssum = Int("ssum")
s.add (ssum == get_sp(numbers, signs))
s.add (ssum % 2**n == 0)

您会发现这也使 Python 版本 运行 更快。

我仍然更喜欢布尔翻译,但是这个翻译提供了一个很好的经验法则:尝试命名中间表达式,这为求解器提供了在公式语义指导下的不错的选择点;而不是一个庞大的输出。正如我提到的,Haskell 绑定不受此影响,因为它在内部将所有公式转换为简单的三操作数 SSA 形式,这使求解器可以更轻松地应用启发式算法。