在 Julia 中使用元编程优化递归函数

Optimizing a recursive function with metaprogramming in Julia

遵循 I am trying to understand what happens exactly and how expressions and generated functions work in Julia within the concept of metaprogramming的方法。

目标是使用表达式和生成的函数优化递归函数(有关具体示例,您可以查看上面提供的 link 中回答的问题)。

考虑以下修改后的斐波那契函数,我想在其中计算斐波那契数列直到 n 并将其乘以一个数字 p

直接的递归实现是

function fib(n::Integer, p::Real)
    if n <= 1
        return 1 * p
    else
        return n * fib(n-1, p)
    end
end

作为第一步,我可以定义一个函数,其中 returns 一个 表达式 而不是计算值

function fib_expr(n::Integer, p::Symbol)
    if n <= 1
        return :(1 * $p)
    else
        return :($n * $(fib_expr(n-1, p)))
    end
end

哪个,例如returns 类似

julia> ex = fib_expr(3, :myp)
:(3 * (2 * (1myp)))

通过这种方式,我得到了一个 完全展开 的表达式,它取决于分配给符号 myp 的值。通过这种方式,我不再看到递归,基本上我是 元编程 :我创建了一个创建另一个 "function" 的函数(尽管在这种情况下我们称之为表达式)。 我现在可以设置 myp = 0.5 并调用 eval(ex) 来计算结果。 但是,这比第一种方法

不过我能做的是按以下方式生成参数函数

@generated function fib_gen{n}(::Type{Val{n}}, p::Real)
    return fib_expr(n, :p)
end

神奇的是,调用 fib_gen(Val{3}, 0.5) 可以完成任务,而且速度快得令人难以置信。
所以,这是怎么回事?

据我了解,在第一次调用fib_gen(Val{3}, 0.5)时,参数函数fib_gen{Val{3}}(...)被编译,其内容是通过fib_expr(3, :p)获得的完全扩展表达式,即3*2*1*p 并用 p 替换为输入值。 它之所以这么快,是因为 fib_gen 基本上只是一系列乘法,而原来的 fib 必须在堆栈上分配每个递归调用使其变慢,我说的对吗?

为了给出一些数字,这是我的简短基准 using BenchmarkTools

julia> @benchmark fib(10, 0.5)
...
mean time: 26.373 ns
...

julia> p = 0.5
0.5

julia> @benchmark eval(fib_expr(10, :p))
...
mean time: 177.906 μs
...

julia> @benchmark fib_gen(Val{10}, 0.5)
...
mean time: 2.046 ns
...

我有很多问题:

此外,我尝试根据

fib_exprfib_gen组合在一个函数中
@generated function fib_tot{n}(::Type{Val{n}}, p::Real)
    if n <= 1
        return :(1 * p)
    else
        return :(n * fib_tot(Val{n-1}, p))
    end
end

但是速度很慢

julia> @benchmark fib_tot(Val{10}, 0.5)
...
mean time: 4.601 μs
...

我在这里做错了什么?甚至可以将 fib_exprfib_gen 组合在一个函数中吗?

我意识到这更像是一本专着而不是一个问题,但是,即使我读了几次 metaprogramming 部分,我也很难掌握所有内容,尤其是应用示例就像这个。

专着回应:

元编程基础知识

首先使用 "normal" 宏会更容易。我将放宽您使用的定义:

function fib_expr(n::Integer, p)
    if n <= 1
        return :(1 * $p)
    else
        return :($n * $(fib_expr(n-1, p)))
    end
end

这允许传递的不仅仅是 p 的符号,例如整数文字或整个表达式。鉴于此,我们可以为相同的功能定义一个宏:

macro fib_macro(n::Integer, p)
    fib_expr(n, p)
end

现在,如果在代码中的任何地方使用 @fib_macro 45 1,在编译时它将首先被一个长嵌套表达式替换

:(45 * (44 * ... * (1 * 1)) ... )

然后正常编译为常量。

真的,这就是宏的全部内容。在编译期间替换语法;通过递归,这可以是编译和计算表达式函数之间任意长的改变。对于本质上不变的东西,但写起来很乏味,它非常有用:一个很好的例子是 Base.Math.@evalpoly.

运行时求值?

但是它有一个问题,你不能检查只有在运行时才知道的值:你不能实现 fib(n) = @fib_macro n 1,因为在编译时,n 是一个代表参数的符号,而不是您可以发送的号码。

下一个最佳解决方案是使用

fib_eval(n::Integer) = eval(fib_expr(n, 1))

有效,但每次调用时都会重复编译过程——这比原始函数的开销大得多,因为现在在运行时,我们执行表达式树上的整个递归,然后在结果上调用编译器。不好。

方法调度&编译

所以我们需要一种混合运行时和编译时的方法。输入 @generated 函数。这些将在运行时分派 type,然后像定义函数体的宏一样工作。

首先是类型调度。如果我们有

f(x) = x + 1

并且有一个函数调用f(1),大约会发生以下情况:

  1. 参数的类型已确定(Int)
  2. 参考函数的方法table,找到最匹配的方法
  3. 方法体是为特定的 Int 参数类型编译的,如果之前没有这样做的话
  4. 编译后的方法在具体参数上进行评估

如果我们随后输入 f(1.0),同样的情况会再次发生,基于相同的函数体,会为 Float64 编译一个新的、不同的专用方法。

值类型和单例类型

现在,Julia 具有您可以将数字用作类型的独特功能。这意味着上面概述的调度过程也适用于以下功能:

g(::Type{Val{N}}) where N = N + 1

这有点棘手。请记住,类型本身就是 Julia 中的值:Int isa Type

在这里,Val{N} 是针对每个 N 一个所谓的 单例类型 ,只有一个实例,即 Val{N}() -- 只是像 Int 是一个有很多实例的类型 0, -1, 1, -2, ....

Type{T} 也是一个单例类型,其单个实例 类型 TInt 是一个 Type{Int},而 Val{3} 是一个 Type{Val{3}} —— 事实上,它们都是它们类型的唯一值。

因此,对于每个 N,都有一个类型 Val{N},是 Type{Val{N}} 的单个实例。因此,g 将针对每个 N 进行分派和编译。这就是我们如何将数字作为类型进行分派。这已经允许优化:

julia> @code_llvm g(Val{1})

define i64 @julia_g_61158(i8**) #0 !dbg !5 {
top:
  ret i64 2
}

julia> @code_llvm f(1)

define i64 @julia_f_61076(i64) #0 !dbg !5 {
top:
  %1 = shl i64 %0, 2
  %2 = or i64 %1, 3
  %3 = mul i64 %2, %0
  %4 = add i64 %3, 2
  ret i64 %4
}

但请记住,它需要在第一次调用时对每个新 N 进行编译。

(如果你不在正文中使用 xfkt(::T) 只是 fkt(x::T) 的缩写。)

集成生成函数和值类型

终于到了生成的函数。它们是对上述调度模式的轻微修改:

  1. 参数的类型已确定(Int)
  2. 参考函数的方法table,找到最匹配的方法
  3. 方法主体被视为宏,使用 Int 参数类型作为参数调用,如果之前没有这样做的话。生成的表达式被编译成一个方法。
  4. 编译后的方法在具体参数上进行评估

此模式允许更改分派函数的每种类型的实现。

对于我们的具体设置,我们要分派 Val 代表斐波那契数列参数的类型:

@generated function fib_gen{n}(::Type{Val{n}}, p::Real)
    return fib_expr(n, :p)
end

你现在看到你的解释是完全正确的:

in the first call to fib_gen(Val{3}, 0.5), the parametric function fib_gen{Val{3}}(...) gets compiled and its content is the fully expanded expression obtained through fib_expr(3, :p), i.e. 3*2*1*p with p substituted with the input value.

我希望整个故事也已经回答了您列出的所有三个问题:

  1. 使用eval的实现每次都复制递归,加上编译的开销
  2. Val 是将数字提升为类型的技巧,而 Type{T} 仅包含 T 的单例类型——但我希望这些示例足够有用
  3. 编译时间 不是在执行之前,因为 JIT -- 它是每次方法第一次编译时,因为它被调用了。

首先,我加入评论:你的问题写得很好,很有建设性。


我已经使用 Julia 0.7-beta 重现了您的结果。

  • @generated fib_tot(一段代码)fib_gen(调用fib_expr)的区别

我的 julia 版本结果是相同的:

julia> @btime fib_tot(Val{10},0.5)
  0.042 ns (0 allocations: 0 bytes)
1.8144e6

julia> @btime fib_gen(Val{10},0.5)
  0.042 ns (0 allocations: 0 bytes)
1.8144e6

有时会将一个函数分解成多个部分 see official doc:performance tips can be useful, however in your peculiar case I do not see why this could be useful. At compile time Julia has everything it needs to optimize fib_tot. There is a branch if n<=1 however n is known at "compile time" thanks to the Type{Val{n}} 技巧,并且应该在生成的(专用)代码中毫无问题地删除此分支。

  • Type{Val{n}} 技巧

为了专门化函数,Julia 推理是根据 参数类型 而不是根据 参数值 .

执行的

例如,不会为每个 n 值生成 foo(n::Int) = ... 的编译版本。您必须定义一个依赖于 n 值的类型才能达到此目标。这正是 Type{Val{n}} 的工作原理:Val{n} 只是一个参数化的空结构:

struct Val{T} end

因此,每个 Val{1}Val{2}、... Val{100}、... 都是不同的类型。因此,如果 foo 定义为:

foo(::Type{Val{n}}) where {n} = ...

每个 foo(Val{1})foo(Val{2})、... foo(Val{100}) 将触发一个专门的 foo 版本(因为参数 type 不同)。

  • eval(fib_expr(n, 1))案例

这个

julia> @btime eval(fib_expr(10, :p))
  401.651 μs (99 allocations: 6.45 KiB)
1.8144e6

很慢,因为你的表达式每次都被(重新)编译。如果您改用宏(请参阅 phg 答案),则可以避免该问题。

  • fib版本

.

julia> @btime fib(10,0.5)
  30.778 ns (0 allocations: 0 bytes)
1.8144e6

fib函数只有一个编译版本。因此,它必须包含所有运行时分支测试等...这解释了它有多慢。


请注意:

  • foo{n}(::Type{Val{n}}) 已弃用的语法

foo{n}(::Type{Val{n}}) 语法已弃用,新语法是 foo(::Type{Val{n}}) where {n}。您可以阅读 Julia doc, parametric methods 了解更多详情。


我的 Julia 版本:

julia> versioninfo()

Julia Version 0.7.0-beta.0
Commit f41b1ecaec (2018-06-24 01:32 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-2603 v3 @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.0 (ORCJIT, haswell)