R 中是否有比 stats:uniroot 函数更快的替代方法?

Is there any faster alternative to stats:uniroot function in R?

我 运行 stats::uniroot 在 data.table 中处理一百万行数据。这是一个玩具示例 -

library(data.table)
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 50000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

print(system.time(
dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
))

在上面的代码中,50,000 行花费的时间接近 8 秒。

有没有比 uniroot 函数更快的替代方法可以大大减少这个时间?

160 秒 (1e6/5e4 * 8) 对我来说对于一百万行来说并没有那么糟糕(尽管你的实际功能可能比你在这里使用的 froot 慢得多?)。这可以简单地并行化,运行 单独的块在不同的核心上(参见 this question 的答案)。

你有多需要extendInt?如果我制作一个仅具有其核心功能的 uniroot() 函数的黑客版本,none 参数测试逻辑等,我可以将速度提高三倍。但是,您的速度增益会少得多如果您的目标函数比您在此处给出的示例慢得多,则令人印象深刻;如果是这种情况,你应该专注于加速你的目标函数(我尝试通过 Rcpp 在 C++ 中重新编码你的 froot,但在这种情况下它并没有真正帮助 - 该函数非常简单函数调用开销占用了大部分时间...)

为了便于基准测试,我只用了 5000 行:

n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

最小函数:

uu <- function(f, lower, upper, tol = 1e-8, maxiter =1000L, ...) {
  f.lower <- f(lower, ...)
  f.upper <- f(upper, ...)
  val <- .External2(stats:::C_zeroin2, function(arg) f(arg, ...),
                    lower, upper, f.lower, f.upper, tol, as.integer(maxiter))
  return(val[1])
}

检查我们得到的结果是否相同:

identical(uniroot(froot, u = 3.242, a=0.5, b=1, interval = c(0.01,100))$root,
          uu(froot, u = 3.242, a=0.5, b=1, lower = 0.01, upper = 100))
## TRUE

基准测试包;将评估包装在函数中以获得紧凑性

library(rbenchmark)
f1 <- function() {
  dt[, c := uniroot(froot_cpp, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
}
f2 <- function() {
  dt[, c := uu(froot, u=u, a=a, b=b, lower = 0.01, upper = 100), by = u]
}
bb <- benchmark(f1(), f2(), 
    columns =c("test", "replications", "elapsed", "relative"))

结果:

  test replications elapsed relative
1 f1()          100  34.616    3.074
2 f2()          100  11.261    1.000

请注意,所示函数的反函数可以显式计算为

f2 <- function(x) (b^a * x / b)^(1/a)
a <- 1/2
b <- 1
all.equal(f(.5), f2(.5))  # f defined below using uniroot
## [1] TRUE

然而,假设在现实中你有一个更复杂的函数,我们可以使用 Chebyshev 近似来得到它的近似值。请注意,a 和 b 是问题中的常量,因此我们还假设是下面的情况,即 f 使用全局环境中设置的常量 a 和 b。下面的代码运行速度比具有 9 次多项式的基准问题中的代码快近 100 倍,并且在 uniroot 给出的答案的 1e-4 范围内。如果您需要更高的精度,请使用更高的度数。

library(data.table)
library(pracma)
set.seed(123)

cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

dt2 <- copy(dt)
f <- function(u) {
  uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root
}

library(microbenchmark)
microbenchmark(times = 10,
  orig = dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u],
  cheb = dt2[, c := chebApprox(u, Vectorize(f), min(u), max(u), 9)]
)
## Unit: milliseconds
##  expr      min       lq      mean    median       uq      max neval cld
##  orig 943.5323 948.9321 961.00361 958.91970 972.6308 982.0060    10   b
##  cheb   9.3752   9.7513  10.67386  10.02555  10.3411  16.9475    10  a 

max(abs(dt$c - dt2$c))
## [1] 8.081021e-05

确切问题有很好的答案,但有一些关于一般 R 实践的注释。

当顺序无关紧要时使用 by

在 OP 中,我们使用 by = u 以便每一行一次 运行 一行。这是低效的! data.table 将对 u 进行排序,确定分组,并且由于它们是非常随机的真实数字,因此最终得到与行一样多的分组。

相反,我们可以使用 Map()mapply() 遍历行,这将提高性能。请注意,尚不清楚 ab 是否真的因行而异 - 如果它们确实是常量,我们可能希望将它们从 data.table 中取出并将它们作为常量传递。

uniroot2 = function(...) uniroot(...)$root ## helper function
dt[, c2 := mapply(uniroot2, u, a,b,
                  MoreArgs = list (f = froot,
                                   interval = c(0.01, 10),
                                   extendInt = 'yes'))]

## for n = 5000

## # A tibble: 2 x 13
##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
##  1 OP           1.17s   1.17s     0.851     170KB     2.55     1     3      1.17s
##  2 no_by      857.2ms 857.2ms     1.17      214KB     3.50     1     3    857.2ms
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled. 

请注意,一旦我们在 mapply 中设置了它,使用 future.apply::future_mapply() 来并行化我们的调用就很简单了。这比我笔记本电脑上的 no_by 示例快 2.5 倍。

library(future.apply)
plan(multisession)
dt[, c3 := future_mapply(uniroot2, u, a,b,
                  MoreArgs = list (f = froot,
                                   interval = c(0.01, 10),
                                   extendInt = 'yes')
                  , future.globals = "cumhaz")] ## see next section for how we could remove this

函数调用需要时间

在您的示例中,您将两个函数定义为:

cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

当性能成为问题并且简化起来微不足道时,您可能想要简化。

froot2 = function(x, u, a, b) b * (x / b) ^ a - u

超过一百万次循环,对 cumhaz() 的额外调用加起来:

x = 2.5; u = 1.5; a = 0.5; b = 1 
bench::mark(froot_rep = for (i in 1:1e6) {froot(x=x, u=u, a=a, b=b)},
            froot2_rep = for (i in 1:1e6) {froot2(x=x, u=u, a=a, b=b)})

## # A tibble: 2 x 13
##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
## 1 froot_rep    4.74s   4.74s     0.211    13.8KB     3.38     1    16      4.74s
## 2 froot2_rep   3.17s   3.17s     0.315    13.8KB     2.84     1     9      3.17s
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled. 

因为 uniroot 会进一步增加默认最大迭代次数为 1,000 的调用!这意味着 cumhaz() 在优化过程中花费了我们 1.5 到 1,500 秒之间的时间。作为@G。 Grothendieck 指出,有时我们实际上可以直接求解并使用直接向量化方法,而不是依赖 unirootoptimize.