如何优化这个 Haskell 代码在亚线性时间内求和素数?

How to optimize this Haskell code summing up the primes in sublinear time?

Project Euler 中的第 10 题是在给定 n.

的情况下求出下面所有素数的和

我简单地把埃拉托色尼筛法生成的素数求和就解决了。然后我发现效率更高 solution by Lucy_Hedgehog(次线性!)。

对于n = 2⋅10^9:

我在 Haskell 中重新实现了相同的算法,因为我正在学习它:

import Data.List

import Data.Map (Map, (!))
import qualified Data.Map as Map

problem10 :: Integer -> Integer
problem10 n = (sieve (Map.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
              where vs = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
                    r  = floor (sqrt (fromIntegral n))

sieve :: Map Integer Integer -> Integer -> Integer -> [Integer] -> Map Integer Integer
sieve m p r vs | p > r     = m
               | otherwise = sieve (if m ! p > m ! (p - 1) then update m vs p else m) (p + 1) r vs

update :: Map Integer Integer -> [Integer] -> Integer -> Map Integer Integer
update m vs p = foldl' decrease m (map (\v -> (v, sumOfSieved m v p)) (takeWhile (>= p*p) vs))

decrease :: Map Integer Integer -> (Integer, Integer) -> Map Integer Integer
decrease m (k, v) = Map.insertWith (flip (-)) k v m

sumOfSieved :: Map Integer Integer -> Integer -> Integer -> Integer
sumOfSieved m v p = p * (m ! (v `div` p) - m ! (p - 1))

main = print $ problem10 $ 2*10^9

我用 ghc -O2 10.hs 和 运行 用 time ./10 编译了它。

它给出了正确答案,但需要大约 7 秒。

我用 ghc -prof -fprof-auto -rtsopts 10 和 运行 用 ./10 +RTS -p -h 编译了它。

10.prof 显示 decrease 占用 52.2% 的时间和 67.5% 的分配。

在 运行ning hp2ps 10.hp 之后,我得到了这样的堆配置文件:

再次看起来 decrease 占用了大部分堆。 GHC 版本 7.6.3.

您将如何优化此 Haskell 代码的 运行 时间?


17 年 6 月 13 日更新:

triedhashtables 包中的可变 Data.HashTable.IO.BasicHashTable 替换了不可变的 Data.Map,但我可能做错了什么,因为对于 tiny n = 30 它已经花费了太长时间,大约 10 秒。怎么了?

17 年 6 月 18 日更新:

Curious about the HashTable performance issues is a good read. I took Sherh's code using mutable Data.HashTable.ST.Linear, but dropped Data.Judy in instead。 运行1.1秒,还是比较慢。

我做了一些小的改进,所以它在我的机器上运行 3.4-3.5 秒。 使用 IntMap.Strict 帮助很大。除此之外,我只是手动执行了一些 ghc 优化以确保安全。并使 Haskell 代码更接近 link 中的 Python 代码。作为下一步,您可以尝试使用一些可变 HashMap。但我不确定... IntMap 不会比某些可变容器快多少,因为它是一个不可变容器。尽管我仍然对它的效率感到惊讶。希望能快点实现。

代码如下:

import Data.List (foldl')
import Data.IntMap.Strict (IntMap, (!))
import qualified Data.IntMap.Strict as IntMap

p :: Int -> Int
p n = (sieve (IntMap.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
               where vs = [n `div` i | i <- [1..r]] ++ [n', n' - 1 .. 1]
                     r  = floor (sqrt (fromIntegral n) :: Double)
                     n' = n `div` r - 1

sieve :: IntMap Int -> Int -> Int -> [Int] -> IntMap Int
sieve m' p' r vs = go m' p'
  where
    go m p | p > r               = m
           | m ! p > m ! (p - 1) = go (update m vs p) (p + 1)
           | otherwise           = go m (p + 1)

update :: IntMap Int -> [Int] -> Int -> IntMap Int
update s vs p = foldl' decrease s (takeWhile (>= p2) vs)
  where
    sp = s ! (p - 1)
    p2 = p * p
    sumOfSieved v = p * (s ! (v `div` p) - sp)
    decrease m  v = IntMap.adjust (subtract $ sumOfSieved v) v m

main :: IO ()
main = print $ p $ 2*10^(9 :: Int) 

更新:

使用可变 hashtables 我已经设法在 Haskell 和 this implementation 上使性能达到 ~5.5sec

另外,我在几个地方使用了未装箱的向量而不是列表。 Linear 哈希似乎是最快的。我认为这可以更快地完成。我注意到 hasthables 包中的 sse42 option。不确定我是否已正确设置它,但即使没有它也运行得那么快。

更新 2 (19.06.2017)

我通过完全删除 judy hashmap 设法使其 3x 比 @Krom 的最佳解决方案(使用我的代码 + 他的地图)更快。相反,只使用普通数组。如果您注意到 S 哈希映射的键是从 1n' 的序列,或者 n div i 对于 i1r。所以我们可以将这样的 HashMap 表示为两个数组,根据搜索键在数组中进行查找。

我的代码+Judy HashMap

$ time ./judy
95673602693282040

real    0m0.590s
user    0m0.588s
sys     0m0.000s

我的代码+我的稀疏图

$ time ./sparse
95673602693282040

real    0m0.203s
user    0m0.196s
sys     0m0.004s

如果不使用 IOUArray 已经生成的向量并且使用 Vector 库并将 readArray 替换为 unsafeRead,则可以更快地完成此操作。但我认为,如果您对尽可能多地优化它并不真正感兴趣,则不应该这样做。

与此解决方案进行比较是作弊,不公平。我希望在 Python 和 C++ 中实现相同的想法会更快。但是@Krom 封闭哈希图的解决方案已经在作弊,因为它使用自定义数据结构而不是标准数据结构。至少你可以看到 Haskell 中标准和最流行的哈希映射并没有那么快。使用更好的算法和更好的临时数据结构可以更好地解决此类问题。

Here's resulting code.

首先作为基准,现有方法的时间安排 在我的机器上:

  1. 原程序贴在问题中:

    time stack exec primorig
    95673602693282040
    
    real    0m4.601s
    user    0m4.387s
    sys     0m0.251s
    
  2. 第二个版本使用 Data.IntMap.Strict 来自 here

    time stack exec primIntMapStrict
    95673602693282040
    
    real    0m2.775s
    user    0m2.753s
    sys     0m0.052s
    
  3. here

    中删除 Data.Judy 的 Shershs 代码
    time stack exec prim-hash2
    95673602693282040
    
    real    0m0.945s
    user    0m0.955s
    sys     0m0.028s
    
  4. 您的 python 解决方案。

    我用

    编译了它
    python -O -m py_compile problem10.py
    

    和时间:

    time python __pycache__/problem10.cpython-36.opt-1.pyc
    95673602693282040
    
    real    0m1.163s
    user    0m1.160s
    sys     0m0.003s
    
  5. 您的 C++ 版本:

    $ g++ -O2 --std=c++11 p10.cpp -o p10
    $ time ./p10
    sum(2000000000) = 95673602693282040
    
    real    0m0.314s
    user    0m0.310s
    sys     0m0.003s
    

我懒得为 slow.hs 提供基准,因为我没有 想要等待它在 运行 时完成,参数为 2*10^9.

亚秒级性能

以下程序 运行 在我的机器上不到一秒钟。

它使用手卷散列映射,它使用封闭散列 线性探测并使用 knuths 哈希函数的一些变体, 参见 here

当然,它有点适合这种情况,因为查找 例如,函数期望搜索到的键存在。

时间:

time stack exec prim
95673602693282040

real    0m0.725s
user    0m0.714s
sys     0m0.047s

首先我实现了我的 hand rolled hashmap 只是为了散列

key `mod` size

并选择了比预期大数倍的尺码 输入,但程序需要 22 秒或更长时间才能完成。

最后是选择哈希函数的问题 适合工作量。

程序如下:

import Data.Maybe
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead)

type Number = Int

data Map = Map { keys :: IOUArray Int Number
               , values :: IOUArray Int Number
               , size :: !Int 
               , factor :: !Int
               }

newMap :: Int -> Int -> IO Map
newMap s f = do
  k <- newArray (0, s-1) 0
  v <- newArray (0, s-1) 0
  return $ Map k v s f 

storeKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
storeKey arr s f key = go ((key * f) `mod` s)
  where
    go :: Int -> IO Int
    go ind = do
      v <- readArray arr ind
      go2 v ind
    go2 v ind
      | v == 0    = do { writeArray arr ind key; return ind; }
      | v == key  = return ind
      | otherwise = go ((ind + 1) `mod` s)

loadKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
loadKey arr s f key = s `seq` key `seq` go ((key *f) `mod` s)
  where
    go :: Int -> IO Int
    go ix = do
      v <- unsafeRead arr ix
      if v == key then return ix else go ((ix + 1) `mod` s)

insertIntoMap :: Map -> (Number, Number) -> IO Map
insertIntoMap m@(Map ks vs s f) (k, v) = do
  ix <- storeKey ks s f k
  writeArray vs ix v
  return m

fromList :: Int -> Int -> [(Number, Number)] -> IO Map
fromList s f xs = do
  m <- newMap s f
  foldM insertIntoMap m xs

(!) :: Map -> Number -> IO Number
(!) (Map ks vs s f) k = do
  ix <- loadKey ks s f k
  readArray vs ix

mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate (Map ks vs s fac) i f = do
  ix <- loadKey ks s fac i
  old <- readArray vs ix
  let x' = f old
  x' `seq` writeArray vs ix x'

r' :: Number -> Number
r'  = floor . sqrt . fromIntegral

vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  

vss' n r = r + n `div` r -1

list' :: Int -> Int -> [Number] -> IO Map
list' s f vs = fromList s f [(i, i * (i + 1) `div` 2 - 1) | i <- vs]

problem10 :: Number -> IO Number
problem10 n = do
      m <- list' (19*vss) (19*vss+7) vs
      nm <- sieve m 2 r vs
      nm ! n
    where vs = vs' n r
          vss = vss' n r
          r  = r' n

sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r     = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs

update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs

decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m

sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)

main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9

我不是散列算法之类的专业人士,所以 这当然可以改进很多。也许我们 Haskellers 应该 改进货架哈希映射或提供一些更简单的。

我的 hashmap,Shershs 代码

如果我将 hashmap 插入 Shershs(见下面的答案)代码,请参阅 here 我们甚至

time stack exec prim-hash2
95673602693282040

real    0m0.601s
user    0m0.604s
sys     0m0.034s

为什么 slow.hs 慢?

如果您通读了源代码 对于 Data.HashTable.ST.Basic 中的函数 insert,您 会看到它删除了旧的键值对并插入 一个新的。它不会查找 "place" 的值和 改变它,就像人们想象的那样,如果有人读到它是 "mutable" 哈希表。这里哈希表本身是可变的, 所以你不需要复制整个哈希表来插入 一个新的键值对,但该对的值位置 不是。不知道是不是slow.hs的全部 很慢,但我的猜测是,这是其中相当大的一部分。

一些小改进

这就是我在尝试改进时遵循的想法 第一次看你的节目。

看,您不需要从键到值的可变映射。 您的密钥集是固定的。你想要一个从键到可变的映射 地方。 (顺便说一下,这是默认情况下从 C++ 获得的内容。)

所以我试着想出那个。我使用了 IntMap IORef 来自 Data.IntMap.StrictData.IORef 首先得到时间 的

tack exec prim
95673602693282040

real    0m2.134s
user    0m2.141s
sys     0m0.028s

我认为使用未装箱的值可能会有所帮助 为了得到它,我使用了 IOUArray Int Int 和 1 个元素 每个而不是 IORef 并得到这些时间:

time stack exec prim
95673602693282040

real    0m2.015s
user    0m2.018s
sys     0m0.038s

差别不大,所以我试图摆脱界限 使用 unsafeRead 和检查 1 个元素数组 unsafeWrite 得到了

的时间
time stack exec prim
95673602693282040

real    0m1.845s
user    0m1.850s
sys     0m0.030s

这是我使用 Data.IntMap.Strict 得到的最好的。

当然我运行每个程序多次看是否 时间是稳定的,运行 时间的差异不是 只是噪音。

看起来这些都只是微优化。

这里是 运行 对我来说最快的程序,无需使用手动数据结构:

import qualified Data.IntMap.Strict as M
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead, unsafeWrite)

type Number = Int
type Place = IOUArray Number Number
type Map = M.IntMap Place

tupleToRef :: (Number, Number) -> IO (Number, Place)
tupleToRef = traverse (newArray (0,0))

insertRefs :: [(Number, Number)] -> IO [(Number, Place)]
insertRefs = traverse tupleToRef

fromList :: [(Number, Number)] -> IO Map 
fromList xs = M.fromList <$> insertRefs xs

(!) :: Map -> Number -> IO Number
(!) m i = unsafeRead (m M.! i) 0

mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate m i f = do
  let place = m M.! i
  old <- unsafeRead place 0
  let x' = f old
  -- make the application of f strict
  x' `seq` unsafeWrite place 0 x'

r' :: Number -> Number
r'  = floor . sqrt . fromIntegral

vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  

list' :: [Number] -> IO Map
list' vs = fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]

problem10 :: Number -> IO Number
problem10 n = do
      m <- list' vs
      nm <- sieve m 2 r vs
      nm ! n
    where vs = vs' n r
          r  = r' n

sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r     = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs

update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs

decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m

sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)

main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9

如果你分析它,你会发现它大部分时间都花在自定义查找函数中 (!), 不知道如何进一步改进。尝试用 {-# INLINE (!) #-} 内联 (!) 没有产生更好的结果;也许 ghc 已经这样做了。

试试这个,让我知道它有多快:

-- sum of primes

import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed

sieve :: Int -> UArray Int Bool
sieve n = runSTUArray $ do
    let m = (n-1) `div` 2
        r = floor . sqrt $ fromIntegral n
    bits <- newArray (0, m-1) True
    forM_ [0 .. r `div` 2 - 1] $ \i -> do
        isPrime <- readArray bits i
        when isPrime $ do
            let a = 2*i*i + 6*i + 3
                b = 2*i*i + 8*i + 6
            forM_ [a, b .. (m-1)] $ \j -> do
                writeArray bits j False
    return bits

primes :: Int -> [Int]
primes n = 2 : [2*i+3 | (i, True) <- assocs $ sieve n]

main = do
    print $ sum $ primes 1000000

您可以 运行 在 ideone. My algorithm is the Sieve of Eratosthenes, and it should be quite fast for small n. For n = 2,000,000,000, the array size may be a problem, in which case you will need to use a segmented sieve. See my blog for more information about the Sieve of Eratosthenes. See this answer 上查看有关分段筛的信息(不幸的是 Haskell 中没有)。

我的这段代码在 0.3 秒内将总和计算为 2⋅10^9,在 19.6 秒内将总和计算为 10^12 (18435588552550705911377)(如果有足够的 RAM)。

import Control.DeepSeq 
import qualified Control.Monad as ControlMonad
import qualified Data.Array as Array
import qualified Data.Array.ST as ArrayST
import qualified Data.Array.Base as ArrayBase

primeLucy :: (Integer -> Integer) -> (Integer -> Integer) -> Integer -> (Integer->Integer)
primeLucy f sf n = g
  where
    r = fromIntegral $ integerSquareRoot n
    ni = fromIntegral n
    loop from to c = let go i = ControlMonad.when (to<=i) (c i >> go (i-1)) in go from

    k = ArrayST.runSTArray $ do
      k <- ArrayST.newListArray (-r,r) $ force $
        [sf (div n (toInteger i)) - sf 1|i<-[r,r-1..1]] ++
        [0] ++
        [sf (toInteger i) - sf 1|i<-[1..r]]
      ControlMonad.forM_ (takeWhile (<=r) primes) $ \p -> do
        l <- ArrayST.readArray k (p-1)
        let q = force $ f (toInteger p)

        let adjust = \i j -> do { v <- ArrayBase.unsafeRead k (i+r); w <- ArrayBase.unsafeRead k (j+r); ArrayBase.unsafeWrite k (i+r) $!! v+q*(l-w) }

        loop (-1)         (-div r p)              $ \i -> adjust i (i*p)
        loop (-div r p-1) (-min r (div ni (p*p))) $ \i -> adjust i (div (-ni) (i*p))
        loop r            (p*p)                   $ \i -> adjust i (div i p)

      return k

    g :: Integer -> Integer
    g m
      | m >= 1 && m <= integerSquareRoot n                       = k Array.! (fromIntegral m)
      | m >= integerSquareRoot n && m <= n && div n (div n m)==m = k Array.! (fromIntegral (negate (div n m)))
      | otherwise = error $ "Function not precalculated for value " ++ show m

primeSum :: Integer -> Integer
primeSum n = (primeLucy id (\m -> div (m*m+m) 2) n) n

如果您的 integerSquareRoot 函数有错误(据报道有些错误),您可以在此处将其替换为 floor . sqrt . fromIntegral

解释:

顾名思义,它是基于 "Lucy Hedgehog" 的著名方法的概括,最终由原始发布者发现。

它允许您计算许多形式为 的和(使用 p 个素数),而无需枚举最多 N 个素数且时间为 O(N^0.75)。

它的输入是函数 f(即,id,如果你想要质数和),它对所有整数的求和函数(即,在这种情况下,前 m 个整数的总和或 div (m*m+m) 2), 和 N.

PrimeLucy returns 一个查找函数 (with p prime) restricted to certain values of n: .