在 Julia 中使用元编程优化递归函数
Optimizing a recursive function with metaprogramming in Julia
遵循 I am trying to understand what happens exactly and how expressions and generated functions work in Julia within the concept of metaprogramming的方法。
目标是使用表达式和生成的函数优化递归函数(有关具体示例,您可以查看上面提供的 link 中回答的问题)。
考虑以下修改后的斐波那契函数,我想在其中计算斐波那契数列直到 n
并将其乘以一个数字 p
。
直接的递归实现是
function fib(n::Integer, p::Real)
if n <= 1
return 1 * p
else
return n * fib(n-1, p)
end
end
作为第一步,我可以定义一个函数,其中 returns 一个 表达式 而不是计算值
function fib_expr(n::Integer, p::Symbol)
if n <= 1
return :(1 * $p)
else
return :($n * $(fib_expr(n-1, p)))
end
end
哪个,例如returns 类似
julia> ex = fib_expr(3, :myp)
:(3 * (2 * (1myp)))
通过这种方式,我得到了一个 完全展开 的表达式,它取决于分配给符号 myp
的值。通过这种方式,我不再看到递归,基本上我是 元编程 :我创建了一个创建另一个 "function" 的函数(尽管在这种情况下我们称之为表达式)。
我现在可以设置 myp = 0.5
并调用 eval(ex)
来计算结果。
但是,这比第一种方法慢。
不过我能做的是按以下方式生成参数函数
@generated function fib_gen{n}(::Type{Val{n}}, p::Real)
return fib_expr(n, :p)
end
神奇的是,调用 fib_gen(Val{3}, 0.5)
可以完成任务,而且速度快得令人难以置信。
所以,这是怎么回事?
据我了解,在第一次调用fib_gen(Val{3}, 0.5)
时,参数函数fib_gen{Val{3}}(...)
被编译,其内容是通过fib_expr(3, :p)
获得的完全扩展表达式,即3*2*1*p
并用 p
替换为输入值。
它之所以这么快,是因为 fib_gen
基本上只是一系列乘法,而原来的 fib
必须在堆栈上分配每个递归调用使其变慢,我说的对吗?
为了给出一些数字,这是我的简短基准 using BenchmarkTools
。
julia> @benchmark fib(10, 0.5)
...
mean time: 26.373 ns
...
julia> p = 0.5
0.5
julia> @benchmark eval(fib_expr(10, :p))
...
mean time: 177.906 μs
...
julia> @benchmark fib_gen(Val{10}, 0.5)
...
mean time: 2.046 ns
...
我有很多问题:
- 为什么第二种情况这么慢?
::Type{Val{n}}
到底是什么意思? (我从上面的答案 link 中复制了它)
- 由于 JIT 编译器,有时我会迷失在 编译时 和 运行 时 发生的事情],这里就是这种情况...
此外,我尝试根据
将fib_expr
和fib_gen
组合在一个函数中
@generated function fib_tot{n}(::Type{Val{n}}, p::Real)
if n <= 1
return :(1 * p)
else
return :(n * fib_tot(Val{n-1}, p))
end
end
但是速度很慢
julia> @benchmark fib_tot(Val{10}, 0.5)
...
mean time: 4.601 μs
...
我在这里做错了什么?甚至可以将 fib_expr
和 fib_gen
组合在一个函数中吗?
我意识到这更像是一本专着而不是一个问题,但是,即使我读了几次 metaprogramming 部分,我也很难掌握所有内容,尤其是应用示例就像这个。
专着回应:
元编程基础知识
首先使用 "normal" 宏会更容易。我将放宽您使用的定义:
function fib_expr(n::Integer, p)
if n <= 1
return :(1 * $p)
else
return :($n * $(fib_expr(n-1, p)))
end
end
这允许传递的不仅仅是 p
的符号,例如整数文字或整个表达式。鉴于此,我们可以为相同的功能定义一个宏:
macro fib_macro(n::Integer, p)
fib_expr(n, p)
end
现在,如果在代码中的任何地方使用 @fib_macro 45 1
,在编译时它将首先被一个长嵌套表达式替换
:(45 * (44 * ... * (1 * 1)) ... )
然后正常编译为常量。
真的,这就是宏的全部内容。在编译期间替换语法;通过递归,这可以是编译和计算表达式函数之间任意长的改变。对于本质上不变的东西,但写起来很乏味,它非常有用:一个很好的例子是 Base.Math.@evalpoly.
运行时求值?
但是它有一个问题,你不能检查只有在运行时才知道的值:你不能实现 fib(n) = @fib_macro n 1
,因为在编译时,n
是一个代表参数的符号,而不是您可以发送的号码。
下一个最佳解决方案是使用
fib_eval(n::Integer) = eval(fib_expr(n, 1))
有效,但每次调用时都会重复编译过程——这比原始函数的开销大得多,因为现在在运行时,我们执行表达式树上的整个递归,然后在结果上调用编译器。不好。
方法调度&编译
所以我们需要一种混合运行时和编译时的方法。输入 @generated
函数。这些将在运行时分派 type,然后像定义函数体的宏一样工作。
首先是类型调度。如果我们有
f(x) = x + 1
并且有一个函数调用f(1)
,大约会发生以下情况:
- 参数的类型已确定(
Int
)
- 参考函数的方法table,找到最匹配的方法
- 方法体是为特定的
Int
参数类型编译的,如果之前没有这样做的话
- 编译后的方法在具体参数上进行评估
如果我们随后输入 f(1.0)
,同样的情况会再次发生,基于相同的函数体,会为 Float64
编译一个新的、不同的专用方法。
值类型和单例类型
现在,Julia 具有您可以将数字用作类型的独特功能。这意味着上面概述的调度过程也适用于以下功能:
g(::Type{Val{N}}) where N = N + 1
这有点棘手。请记住,类型本身就是 Julia 中的值:Int isa Type
。
在这里,Val{N}
是针对每个 N
一个所谓的 单例类型 ,只有一个实例,即 Val{N}()
-- 只是像 Int
是一个有很多实例的类型 0
, -1
, 1
, -2
, ....
Type{T}
也是一个单例类型,其单个实例 类型 T
。 Int
是一个 Type{Int}
,而 Val{3}
是一个 Type{Val{3}}
—— 事实上,它们都是它们类型的唯一值。
因此,对于每个 N
,都有一个类型 Val{N}
,是 Type{Val{N}}
的单个实例。因此,g
将针对每个 N
进行分派和编译。这就是我们如何将数字作为类型进行分派。这已经允许优化:
julia> @code_llvm g(Val{1})
define i64 @julia_g_61158(i8**) #0 !dbg !5 {
top:
ret i64 2
}
julia> @code_llvm f(1)
define i64 @julia_f_61076(i64) #0 !dbg !5 {
top:
%1 = shl i64 %0, 2
%2 = or i64 %1, 3
%3 = mul i64 %2, %0
%4 = add i64 %3, 2
ret i64 %4
}
但请记住,它需要在第一次调用时对每个新 N
进行编译。
(如果你不在正文中使用 x
,fkt(::T)
只是 fkt(x::T)
的缩写。)
集成生成函数和值类型
终于到了生成的函数。它们是对上述调度模式的轻微修改:
- 参数的类型已确定(
Int
)
- 参考函数的方法table,找到最匹配的方法
- 方法主体被视为宏,使用
Int
参数类型作为参数调用,如果之前没有这样做的话。生成的表达式被编译成一个方法。
- 编译后的方法在具体参数上进行评估
此模式允许更改分派函数的每种类型的实现。
对于我们的具体设置,我们要分派 Val
代表斐波那契数列参数的类型:
@generated function fib_gen{n}(::Type{Val{n}}, p::Real)
return fib_expr(n, :p)
end
你现在看到你的解释是完全正确的:
in the first call to fib_gen(Val{3}, 0.5)
, the parametric function
fib_gen{Val{3}}(...)
gets compiled and its content is the fully
expanded expression obtained through fib_expr(3, :p)
, i.e. 3*2*1*p
with p
substituted with the input value.
我希望整个故事也已经回答了您列出的所有三个问题:
- 使用
eval
的实现每次都复制递归,加上编译的开销
Val
是将数字提升为类型的技巧,而 Type{T}
仅包含 T
的单例类型——但我希望这些示例足够有用
- 编译时间 不是在执行之前,因为 JIT -- 它是每次方法第一次编译时,因为它被调用了。
首先,我加入评论:你的问题写得很好,很有建设性。
我已经使用 Julia 0.7-beta 重现了您的结果。
- @generated fib_tot(一段代码)和fib_gen(调用fib_expr)的区别
我的 julia 版本结果是相同的:
julia> @btime fib_tot(Val{10},0.5)
0.042 ns (0 allocations: 0 bytes)
1.8144e6
julia> @btime fib_gen(Val{10},0.5)
0.042 ns (0 allocations: 0 bytes)
1.8144e6
有时会将一个函数分解成多个部分 see official doc:performance tips can be useful, however in your peculiar case I do not see why this could be useful. At compile time Julia has everything it needs to optimize fib_tot
. There is a branch if n<=1
however n
is known at "compile time" thanks to the Type{Val{n}}
技巧,并且应该在生成的(专用)代码中毫无问题地删除此分支。
Type{Val{n}}
技巧
为了专门化函数,Julia 推理是根据 参数类型 而不是根据 参数值 .
执行的
例如,不会为每个 n
值生成 foo(n::Int) = ...
的编译版本。您必须定义一个依赖于 n
值的类型才能达到此目标。这正是 Type{Val{n}}
的工作原理:Val{n}
只是一个参数化的空结构:
struct Val{T} end
因此,每个 Val{1}
、Val{2}
、... Val{100}
、... 都是不同的类型。因此,如果 foo 定义为:
foo(::Type{Val{n}}) where {n} = ...
每个 foo(Val{1})
、foo(Val{2})
、... foo(Val{100})
将触发一个专门的 foo 版本(因为参数 type 不同)。
eval(fib_expr(n, 1))
案例
这个
julia> @btime eval(fib_expr(10, :p))
401.651 μs (99 allocations: 6.45 KiB)
1.8144e6
很慢,因为你的表达式每次都被(重新)编译。如果您改用宏(请参阅 phg 答案),则可以避免该问题。
fib
版本
.
julia> @btime fib(10,0.5)
30.778 ns (0 allocations: 0 bytes)
1.8144e6
此fib
函数只有一个编译版本。因此,它必须包含所有运行时分支测试等...这解释了它有多慢。
请注意:
foo{n}(::Type{Val{n}})
已弃用的语法
foo{n}(::Type{Val{n}})
语法已弃用,新语法是 foo(::Type{Val{n}}) where {n}
。您可以阅读 Julia doc, parametric methods 了解更多详情。
我的 Julia 版本:
julia> versioninfo()
Julia Version 0.7.0-beta.0
Commit f41b1ecaec (2018-06-24 01:32 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Xeon(R) CPU E5-2603 v3 @ 1.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-6.0.0 (ORCJIT, haswell)
遵循
目标是使用表达式和生成的函数优化递归函数(有关具体示例,您可以查看上面提供的 link 中回答的问题)。
考虑以下修改后的斐波那契函数,我想在其中计算斐波那契数列直到 n
并将其乘以一个数字 p
。
直接的递归实现是
function fib(n::Integer, p::Real)
if n <= 1
return 1 * p
else
return n * fib(n-1, p)
end
end
作为第一步,我可以定义一个函数,其中 returns 一个 表达式 而不是计算值
function fib_expr(n::Integer, p::Symbol)
if n <= 1
return :(1 * $p)
else
return :($n * $(fib_expr(n-1, p)))
end
end
哪个,例如returns 类似
julia> ex = fib_expr(3, :myp)
:(3 * (2 * (1myp)))
通过这种方式,我得到了一个 完全展开 的表达式,它取决于分配给符号 myp
的值。通过这种方式,我不再看到递归,基本上我是 元编程 :我创建了一个创建另一个 "function" 的函数(尽管在这种情况下我们称之为表达式)。
我现在可以设置 myp = 0.5
并调用 eval(ex)
来计算结果。
但是,这比第一种方法慢。
不过我能做的是按以下方式生成参数函数
@generated function fib_gen{n}(::Type{Val{n}}, p::Real)
return fib_expr(n, :p)
end
神奇的是,调用 fib_gen(Val{3}, 0.5)
可以完成任务,而且速度快得令人难以置信。
所以,这是怎么回事?
据我了解,在第一次调用fib_gen(Val{3}, 0.5)
时,参数函数fib_gen{Val{3}}(...)
被编译,其内容是通过fib_expr(3, :p)
获得的完全扩展表达式,即3*2*1*p
并用 p
替换为输入值。
它之所以这么快,是因为 fib_gen
基本上只是一系列乘法,而原来的 fib
必须在堆栈上分配每个递归调用使其变慢,我说的对吗?
为了给出一些数字,这是我的简短基准 using BenchmarkTools
。
julia> @benchmark fib(10, 0.5)
...
mean time: 26.373 ns
...
julia> p = 0.5
0.5
julia> @benchmark eval(fib_expr(10, :p))
...
mean time: 177.906 μs
...
julia> @benchmark fib_gen(Val{10}, 0.5)
...
mean time: 2.046 ns
...
我有很多问题:
- 为什么第二种情况这么慢?
::Type{Val{n}}
到底是什么意思? (我从上面的答案 link 中复制了它)- 由于 JIT 编译器,有时我会迷失在 编译时 和 运行 时 发生的事情],这里就是这种情况...
此外,我尝试根据
将fib_expr
和fib_gen
组合在一个函数中
@generated function fib_tot{n}(::Type{Val{n}}, p::Real)
if n <= 1
return :(1 * p)
else
return :(n * fib_tot(Val{n-1}, p))
end
end
但是速度很慢
julia> @benchmark fib_tot(Val{10}, 0.5)
...
mean time: 4.601 μs
...
我在这里做错了什么?甚至可以将 fib_expr
和 fib_gen
组合在一个函数中吗?
我意识到这更像是一本专着而不是一个问题,但是,即使我读了几次 metaprogramming 部分,我也很难掌握所有内容,尤其是应用示例就像这个。
专着回应:
元编程基础知识
首先使用 "normal" 宏会更容易。我将放宽您使用的定义:
function fib_expr(n::Integer, p)
if n <= 1
return :(1 * $p)
else
return :($n * $(fib_expr(n-1, p)))
end
end
这允许传递的不仅仅是 p
的符号,例如整数文字或整个表达式。鉴于此,我们可以为相同的功能定义一个宏:
macro fib_macro(n::Integer, p)
fib_expr(n, p)
end
现在,如果在代码中的任何地方使用 @fib_macro 45 1
,在编译时它将首先被一个长嵌套表达式替换
:(45 * (44 * ... * (1 * 1)) ... )
然后正常编译为常量。
真的,这就是宏的全部内容。在编译期间替换语法;通过递归,这可以是编译和计算表达式函数之间任意长的改变。对于本质上不变的东西,但写起来很乏味,它非常有用:一个很好的例子是 Base.Math.@evalpoly.
运行时求值?
但是它有一个问题,你不能检查只有在运行时才知道的值:你不能实现 fib(n) = @fib_macro n 1
,因为在编译时,n
是一个代表参数的符号,而不是您可以发送的号码。
下一个最佳解决方案是使用
fib_eval(n::Integer) = eval(fib_expr(n, 1))
有效,但每次调用时都会重复编译过程——这比原始函数的开销大得多,因为现在在运行时,我们执行表达式树上的整个递归,然后在结果上调用编译器。不好。
方法调度&编译
所以我们需要一种混合运行时和编译时的方法。输入 @generated
函数。这些将在运行时分派 type,然后像定义函数体的宏一样工作。
首先是类型调度。如果我们有
f(x) = x + 1
并且有一个函数调用f(1)
,大约会发生以下情况:
- 参数的类型已确定(
Int
) - 参考函数的方法table,找到最匹配的方法
- 方法体是为特定的
Int
参数类型编译的,如果之前没有这样做的话 - 编译后的方法在具体参数上进行评估
如果我们随后输入 f(1.0)
,同样的情况会再次发生,基于相同的函数体,会为 Float64
编译一个新的、不同的专用方法。
值类型和单例类型
现在,Julia 具有您可以将数字用作类型的独特功能。这意味着上面概述的调度过程也适用于以下功能:
g(::Type{Val{N}}) where N = N + 1
这有点棘手。请记住,类型本身就是 Julia 中的值:Int isa Type
。
在这里,Val{N}
是针对每个 N
一个所谓的 单例类型 ,只有一个实例,即 Val{N}()
-- 只是像 Int
是一个有很多实例的类型 0
, -1
, 1
, -2
, ....
Type{T}
也是一个单例类型,其单个实例 类型 T
。 Int
是一个 Type{Int}
,而 Val{3}
是一个 Type{Val{3}}
—— 事实上,它们都是它们类型的唯一值。
因此,对于每个 N
,都有一个类型 Val{N}
,是 Type{Val{N}}
的单个实例。因此,g
将针对每个 N
进行分派和编译。这就是我们如何将数字作为类型进行分派。这已经允许优化:
julia> @code_llvm g(Val{1})
define i64 @julia_g_61158(i8**) #0 !dbg !5 {
top:
ret i64 2
}
julia> @code_llvm f(1)
define i64 @julia_f_61076(i64) #0 !dbg !5 {
top:
%1 = shl i64 %0, 2
%2 = or i64 %1, 3
%3 = mul i64 %2, %0
%4 = add i64 %3, 2
ret i64 %4
}
但请记住,它需要在第一次调用时对每个新 N
进行编译。
(如果你不在正文中使用 x
,fkt(::T)
只是 fkt(x::T)
的缩写。)
集成生成函数和值类型
终于到了生成的函数。它们是对上述调度模式的轻微修改:
- 参数的类型已确定(
Int
) - 参考函数的方法table,找到最匹配的方法
- 方法主体被视为宏,使用
Int
参数类型作为参数调用,如果之前没有这样做的话。生成的表达式被编译成一个方法。 - 编译后的方法在具体参数上进行评估
此模式允许更改分派函数的每种类型的实现。
对于我们的具体设置,我们要分派 Val
代表斐波那契数列参数的类型:
@generated function fib_gen{n}(::Type{Val{n}}, p::Real)
return fib_expr(n, :p)
end
你现在看到你的解释是完全正确的:
in the first call to
fib_gen(Val{3}, 0.5)
, the parametric functionfib_gen{Val{3}}(...)
gets compiled and its content is the fully expanded expression obtained throughfib_expr(3, :p)
, i.e.3*2*1*p
withp
substituted with the input value.
我希望整个故事也已经回答了您列出的所有三个问题:
- 使用
eval
的实现每次都复制递归,加上编译的开销 Val
是将数字提升为类型的技巧,而Type{T}
仅包含T
的单例类型——但我希望这些示例足够有用- 编译时间 不是在执行之前,因为 JIT -- 它是每次方法第一次编译时,因为它被调用了。
首先,我加入评论:你的问题写得很好,很有建设性。
我已经使用 Julia 0.7-beta 重现了您的结果。
- @generated fib_tot(一段代码)和fib_gen(调用fib_expr)的区别
我的 julia 版本结果是相同的:
julia> @btime fib_tot(Val{10},0.5)
0.042 ns (0 allocations: 0 bytes)
1.8144e6
julia> @btime fib_gen(Val{10},0.5)
0.042 ns (0 allocations: 0 bytes)
1.8144e6
有时会将一个函数分解成多个部分 see official doc:performance tips can be useful, however in your peculiar case I do not see why this could be useful. At compile time Julia has everything it needs to optimize fib_tot
. There is a branch if n<=1
however n
is known at "compile time" thanks to the Type{Val{n}}
技巧,并且应该在生成的(专用)代码中毫无问题地删除此分支。
Type{Val{n}}
技巧
为了专门化函数,Julia 推理是根据 参数类型 而不是根据 参数值 .
执行的例如,不会为每个 n
值生成 foo(n::Int) = ...
的编译版本。您必须定义一个依赖于 n
值的类型才能达到此目标。这正是 Type{Val{n}}
的工作原理:Val{n}
只是一个参数化的空结构:
struct Val{T} end
因此,每个 Val{1}
、Val{2}
、... Val{100}
、... 都是不同的类型。因此,如果 foo 定义为:
foo(::Type{Val{n}}) where {n} = ...
每个 foo(Val{1})
、foo(Val{2})
、... foo(Val{100})
将触发一个专门的 foo 版本(因为参数 type 不同)。
eval(fib_expr(n, 1))
案例
这个
julia> @btime eval(fib_expr(10, :p))
401.651 μs (99 allocations: 6.45 KiB)
1.8144e6
很慢,因为你的表达式每次都被(重新)编译。如果您改用宏(请参阅 phg 答案),则可以避免该问题。
fib
版本
.
julia> @btime fib(10,0.5)
30.778 ns (0 allocations: 0 bytes)
1.8144e6
此fib
函数只有一个编译版本。因此,它必须包含所有运行时分支测试等...这解释了它有多慢。
请注意:
foo{n}(::Type{Val{n}})
已弃用的语法
foo{n}(::Type{Val{n}})
语法已弃用,新语法是 foo(::Type{Val{n}}) where {n}
。您可以阅读 Julia doc, parametric methods 了解更多详情。
我的 Julia 版本:
julia> versioninfo()
Julia Version 0.7.0-beta.0
Commit f41b1ecaec (2018-06-24 01:32 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Xeon(R) CPU E5-2603 v3 @ 1.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-6.0.0 (ORCJIT, haswell)