高效Haskell相当于NumPy的argsort

Efficient Haskell equivalent to NumPy's argsort

是否有与 NumPy 的 argsort 函数等效的标准 Haskell?

我正在使用 HMatrix,因此,我想要一个与 Vector R 兼容的函数,它是 Data.Vector.Storable.Vector Double 的别名。下面的 argSort 函数是我目前使用的实现:

{-# LANGUAGE NoImplicitPrelude #-}

module Main where

import qualified Data.List as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import           Prelude (($), Double, IO, Int, compare, print, snd)

a :: VS.Vector Double
a = VS.fromList [40.0, 20.0, 10.0, 11.0]

argSort :: VS.Vector Double -> V.Vector Int
argSort xs = V.fromList (L.map snd $ L.sortBy (\(x0, _) (x1, _) -> compare x0 x1) (L.zip (VS.toList xs) [0..]))

main :: IO ()
main = print $ argSort a -- yields [2,3,1,0]

我使用显式限定 import 只是为了清楚说明每个类型和函数的来源。

此实现不是非常有效,因为它将输入向量转换为列表并将结果转换回向量。某处是否存在这样的东西(但效率更高)?

更新

@leftaroundabout 有一个很好的解决方案。这是我最终得到的解决方案:

module LAUtil.Sorting
  ( IndexVector
  , argSort
  )
  where

import           Control.Monad
import           Control.Monad.ST
import           Data.Ord
import qualified Data.Vector.Algorithms.Intro as VAI
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import           Numeric.LinearAlgebra

type IndexVector = VU.Vector Int

argSort :: Vector R -> IndexVector
argSort xs = runST $ do
    let l = VS.length xs
    t0 <- VUM.new l
    forM_ [0..l - 1] $
        \i -> VUM.unsafeWrite t0 i (i, (VS.!) xs i)
    VAI.sortBy (comparing snd) t0
    t1 <- VUM.new l
    forM_ [0..l - 1] $
        \i -> VUM.unsafeRead t0 i >>= \(x, _) -> VUM.unsafeWrite t1 i x
    VU.freeze t1

由于数据向量是 Storable,因此可以更直接地用于 Numeric.LinearAlgebra。这使用未装箱的向量作为索引。

使用vector-algorithms:

import Data.Ord (comparing)

import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Algorithms.Intro as VAlgo

argSort :: (Ord a, VU.Unbox a) => VU.Vector a -> VU.Vector Int
argSort xs = VU.map fst $ VU.create $ do
    xsi <- VU.thaw $ VU.indexed xs
    VAlgo.sortBy (comparing snd) xsi
    return xsi

请注意,这些是 Unboxed 而不是 Storable 向量。后者需要做出一些权衡以允许不纯的 C FFI 操作并且不能正确处理异构元组。您当然可以始终 convert 往返于可存储向量。

对我来说效果更好的是使用 Data.map,因为它受列表融合的影响,所以速度提高了。这里 n = 长度 xs。

import Data.Map as M (toList, fromList, toAscList)

    out :: Int -> [Double] -> [Int]
    out n !xs = let !a=  (M.toAscList (M.fromList $! (zip xs [0..n])))
                    !res=a `seq` L.map snd a
                in res

然而,这仅适用于定期列表,如:

out 12 [1,2,3,4,1,2,3,4,1,2,3,4] == out 12 [1,2,3,4,1,3,2,4,1,2,3,4]