在 Julia 中将函数参数更改为关键字似乎会引入类型不稳定性

Changing function arguments to keywords in Julia seems to introduce type instability

我有一个程序,其中 main() 函数有四个参数。当我对该函数运行 @code_warntype 时,似乎没有什么不妥:所有变量都有确定的类型,也没有出现 Union 或其他明显的警告标志。

抱歉,程序很长,但我不确定如何在保留问题的情况下缩短它:

"""
    main(n, dice=6, start=1, modal=3) -> (modal_string, modal_squares, modal_percents)

Simulate `n` turns of a single token moving round a 40-square Monopoly-style
board, throwing two `dice`-sided dice per turn. Report the `modal` squares on
which turns ended most often.

Rules as implemented:
- A turn is counted only when its turn number is `>= start`.
- Landing on "G2J" sends the token to "JAIL".
- A third consecutive double also sends the token to "JAIL" and resets the streak.
- Landing on a "CC…" or "CH…" square draws the top card of that 16-card deck,
  which may move the token.

# Returns
- `modal_string`: the zero-based board numbers of the `modal` most visited
  squares, each zero-padded to two digits and concatenated.
- `modal_squares`: the names of those squares, most visited first (ties keep
  board order).
- `modal_percents`: the percentage of counted turns that ended on each of them.
  All entries are `0.0` when no turn was counted (`start > n`).

# Throws
- `ArgumentError` if `modal` is not in `0:40`.

On Julia ≥ 0.7, `shuffle` lives in the `Random` standard library, so the
enclosing module needs `using Random`.
"""
function main(n::Int, dice::Int=6, start::Int=1, modal::Int=3) ::Tuple{String, Vector{String}, Vector{Float64}}
    board = String["GO", "A1", "CC1", "A2", "T1", "R1", "B1", "CH1", "B2", "B3",
        "JAIL", "C1", "U1", "C2", "C3", "R2", "D1", "CC2", "D2", "D3",
        "FP", "E1", "CH2", "E2", "E3", "R3", "F1", "F2", "U2", "F3",
        "G2J", "G1", "G2", "CC3", "G3", "R4", "CH3", "H1", "T2", "H2"]
    nsquares = length(board)
    0 <= modal <= nsquares ||
        throw(ArgumentError("modal must be between 0 and $nsquares, got $modal"))
    # Board index of every square name. Each lookup is O(1) instead of a scan of
    # `board`, and this avoids `findfirst(A, v)`, which was deprecated in Julia 0.7.
    square_of = Dict(name => i for (i, name) in enumerate(board))
    cc_cards = shuffle(collect(1:16))
    ch_cards = shuffle(collect(1:16))

    # Rotate a deck by one position (its last card moves to the front). Reading
    # `cards[1]` after each rotation then visits every card once per 16 draws.
    # `insert!` works on every Julia version; `unshift!` was renamed in 0.7.
    function rotate_deck!(cards::Vector{Int})
        insert!(cards, 1, pop!(cards))
        return cards
    end

    # Community Chest: cards 1 and 2 send the token to GO and JAIL respectively.
    # Every other card leaves `square` unchanged.
    function take_cc_card(square::Int, cards::Vector{Int})::Tuple{Int, Vector{Int}}
        drawn_cc = cards[1]
        if drawn_cc == 1
            square = square_of["GO"]
        elseif drawn_cc == 2
            square = square_of["JAIL"]
        end
        rotate_deck!(cards)
        return square, cards
    end

    # Chance: cards 1-6 jump to a fixed square, 7-8 go to the next "R" square,
    # 9 goes to the next "U" square, and 10 moves back three squares.
    # Cards 11-16 leave `square` unchanged.
    function take_ch_card(square::Int, cards::Vector{Int})::Tuple{Int, Vector{Int}}
        drawn_ch = cards[1]
        current = board[square]
        if drawn_ch == 1
            square = square_of["GO"]
        elseif drawn_ch == 2
            square = square_of["JAIL"]
        elseif drawn_ch == 3
            square = square_of["C1"]
        elseif drawn_ch == 4
            square = square_of["E3"]
        elseif drawn_ch == 5
            square = square_of["H2"]
        elseif drawn_ch == 6
            square = square_of["R1"]
        elseif drawn_ch == 7 || drawn_ch == 8
            if current == "CH1"
                square = square_of["R2"]
            elseif current == "CH2"
                square = square_of["R3"]
            elseif current == "CH3"
                square = square_of["R1"]
            end
        elseif drawn_ch == 9
            if current == "CH1"
                square = square_of["U1"]
            elseif current == "CH2"
                square = square_of["U2"]
            elseif current == "CH3"
                square = square_of["U1"]
            end
        elseif drawn_ch == 10
            # `mod1` wraps into 1:nsquares. It also handles squares 1-3, where
            # `square - 3` is <= 0; the precedence slip `square - 3 % 40`
            # (i.e. `square - 3`) mishandled that case.
            square = mod1(square - 3, nsquares)
        end
        rotate_deck!(cards)
        return square, cards
    end

    result = zeros(Int, nsquares)
    consec_doubles = 0
    square = 1
    for turn = 1:n
        # Sample the range directly instead of allocating `collect(1:dice)` per throw.
        throw_1 = rand(1:dice)
        throw_2 = rand(1:dice)
        consec_doubles = throw_1 == throw_2 ? consec_doubles + 1 : 0
        if consec_doubles != 3
            # `mod1` keeps the position in 1:nsquares, wrapping from H2 back to GO.
            square = mod1(square + throw_1 + throw_2, nsquares)
            if board[square] == "G2J"
                square = square_of["JAIL"]
            elseif startswith(board[square], "CC")
                square, cc_cards = take_cc_card(square, cc_cards)
            elseif startswith(board[square], "CH")
                square, ch_cards = take_ch_card(square, ch_cards)
                # A Chance card (e.g. moving back three squares) can land on a
                # Community Chest square, which then draws from that deck too.
                if startswith(board[square], "CC")
                    square, cc_cards = take_cc_card(square, cc_cards)
                end
            end
        else
            # Third consecutive double: go straight to jail and reset the streak.
            square = square_of["JAIL"]
            consec_doubles = 0
        end
        if turn >= start
            result[square] += 1
        end
    end

    total = sum(result)
    # With no counted turns (start > n), report 0% everywhere instead of 0/0 = NaN.
    percents = total == 0 ? zeros(nsquares) : result .* 100 ./ total
    # `sortperm` breaks ties by index, so equally frequent squares keep board order.
    top = sortperm(percents, rev = true)[1:modal]
    modal_squares = board[top]
    modal_percents = percents[top]
    # Zero-based square numbers, each zero-padded to two characters.
    modal_string = join(lpad(string(i - 1), 2, "0") for i in top)
    return modal_string, modal_squares, modal_percents
end

@code_warntype main(1_000_000, 4, 101, 5)

但是,当我在第一个参数后用分号代替逗号,从而把后三个参数改为关键字参数时...

function main(n::Int; dice::Int=6, start::Int=1, modal::Int=3) ::Tuple{String, Vector{String}, Vector{Float64}}

...似乎就遇到了类型稳定性问题。

@code_warntype main(1_000_000, dice=4, start=101, modal=5)

运行 @code_warntype 时,输出中出现了若干类型为 ANY 的临时变量。

奇怪的是,这似乎并没有带来性能上的影响:三次基准测试的平均结果中,'argument'(位置参数)版本耗时 431.594 毫秒,'keyword'(关键字参数)版本耗时 413.149 毫秒。不过,我很想知道:

(a) 为什么会这样;

(b) 一般来说,出现 ANY 类型的临时变量是否值得担心;以及

(c) 一般而言,从性能的角度来看,使用关键字参数代替普通的位置参数是否有任何优势。

这是我对这三个问题的看法。在本回答中,除非我在文末明确说明指的是 Julia 0.7,否则默认使用 Julia 0.6.3。

(a) 含有 Any 变量的那部分代码负责处理关键字参数(例如,检查传入的关键字参数是否为函数签名所允许)。原因是在函数内部,关键字参数是以 Vector{Any} 的形式接收的,该向量包含 ([参数名], [参数值]) 形式的元组。函数真正的 "work"(实际计算)发生在这段含有 Any 变量的代码之后。

对比以下两次调用就可以看出这一点:

@code_warntype main(1_000_000, dice=4, start=101, modal=5)

@code_warntype main(1_000_000)

(两次调用针对的都是带关键字参数的版本)。第二次调用的报告只有第一次调用报告中的最后一行;其余各行都用于处理传入的关键字参数。

(b) 一般来说这当然值得关注,但在本例中这是无法避免的:这些 Any 类型的变量保存的是关键字参数名称的相关信息。

(c) 一般来说,可以认为位置参数不会比关键字参数慢,而且可能更快。下面是一个最小可复现示例(MWE)(实际上,如果你运行 @code_warntype f(a=10),也会看到这个 Any 变量):

julia> using BenchmarkTools

julia> f(;a::Int=1) = a+1
f (generic function with 1 method)

julia> g(a::Int=1) = a+1
g (generic function with 2 methods)

julia> @benchmark f()
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.974 ns (0.00% GC)
  maximum time:     14.463 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark f(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     52.994 ns (0.00% GC)
  median time:      54.413 ns (0.00% GC)
  mean time:        65.207 ns (10.65% GC)
  maximum time:     3.466 μs (94.78% GC)
  --------------
  samples:          10000
  evals/sample:     986

julia> @benchmark g()
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.954 ns (0.00% GC)
  maximum time:     13.062 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark g(10)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.949 ns (0.00% GC)
  maximum time:     13.063 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

现在可以看到,关键字参数的开销实际上只在显式传入关键字参数时才会产生(也就是 @code_warntype 中出现 Any 变量的情形,因为此时 Julia 需要做更多的工作)。请注意,这个开销很小,只有在函数本身工作量很少时才看得出来;对于计算量大的函数,大多数情况下可以忽略。

另外请注意,如果不指定关键字参数的类型,显式传入关键字参数值时的开销会更大,因为 Julia 不会根据关键字参数的类型进行分派(同样可以运行 @code_warntype 来验证这一点):

julia> h(;a=1) = a+1
h (generic function with 1 method)

julia> @benchmark h()
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.960 ns (0.00% GC)
  maximum time:     13.996 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark h(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     75.433 ns (0.00% GC)
  median time:      77.355 ns (0.00% GC)
  mean time:        89.037 ns (7.87% GC)
  maximum time:     2.128 μs (89.73% GC)
  --------------
  samples:          10000
  evals/sample:     971

在 Julia 0.7 中,关键字参数以包装了 NamedTuple 的 Base.Iterators.Pairs 形式接收,因此 Julia 在编译时就知道传入参数的类型。这意味着使用关键字参数比在 Julia 0.6.3 中更快(但同样,不要指望它们比位置参数更快)。你可以在 Julia 0.7 下运行与上面类似的基准测试来验证这一点(我只是修改了函数的内容,让 Julia 编译器有更多工作可做)。你也可以对这些函数运行 @code_warntype,会看到 Julia 0.7 中的类型推断效果更好:

julia> using BenchmarkTools

julia> f(;a::Int=1) = [a]
f (generic function with 1 method)

julia> g(a::Int=1) = [a]
g (generic function with 2 methods)

julia> h(;a=1) = [a]
h (generic function with 1 method)

julia> @benchmark f()
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.523 ns (0.00% GC)
  mean time:        50.576 ns (22.80% GC)
  maximum time:     53.465 μs (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark f(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.057 ns (0.00% GC)
  mean time:        50.739 ns (22.83% GC)
  maximum time:     55.303 μs (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark g()
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.523 ns (0.00% GC)
  mean time:        50.529 ns (22.77% GC)
  maximum time:     54.501 μs (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark g(10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.523 ns (0.00% GC)
  mean time:        50.899 ns (23.27% GC)
  maximum time:     56.246 μs (99.90% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark h()
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.257 ns (0.00% GC)
  median time:      34.057 ns (0.00% GC)
  mean time:        50.924 ns (22.87% GC)
  maximum time:     55.724 μs (99.88% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark h(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.057 ns (0.00% GC)
  mean time:        50.864 ns (22.60% GC)
  maximum time:     53.389 μs (99.83% GC)
  --------------
  samples:          10000
  evals/sample:     1000