如何避免自定义 Julia 迭代器中的内存分配?

How to avoid memory allocations in custom Julia iterators?

考虑以下 Julia“复合”迭代器:它把两个迭代器 a 和 b(假定二者都已按 order 排好序)合并为单个有序序列:

# Lazily merges two inputs `a` and `b`, each assumed already sorted by `order`,
# into one sorted stream. `T` is the promoted element type of both inputs; the
# `iterate` methods convert every yielded value to it.
struct MergeSorted{T,A,B,O}
    a::A       # first sorted input
    b::B       # second sorted input
    order::O   # ordering passed to `Base.Order.lt` (default `Base.Order.Forward`)

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        # Derived from the input *types* only, so it is a compile-time constant.
        elem = promote_type(eltype(A), eltype(B))
        return new{elem,A,B,O}(a, b, order)
    end
end

# Iterator traits. `eltype` reports the promoted element type `T` chosen by the
# constructor (`<:` so it matches every concrete `MergeSorted{T,...}`).
# The merged length is not tracked, so declare `SizeUnknown`: with the default
# `HasLength` trait, `collect` and similar consumers would call `length`, which
# this iterator does not define.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# One merge step of the original version. The default `state` seeds the merge
# with the first `iterate` result of each input. The state is a plain tuple
# `(a_result, b_result)` of pending `iterate` results (`nothing` once an input is
# exhausted); its type differs between return paths, so inference cannot give
# the state a single concrete type (see the `@code_warntype` output).
# NOTE(review): `T(x)` assumes `T` has a one-argument constructor; for an
# abstract `T` (e.g. `Any`) `convert(T, x)` may be needed — confirm.
@inline function Base.iterate(self::MergeSorted{T}, 
                      state=(iterate(self.a), iterate(self.b))) where T
    a_result, b_result = state
    # `b` exhausted: drain `a`, or stop once both are exhausted.
    if b_result === nothing
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), (iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    # Take from `a` only when its head is strictly before `b`'s head, so on
    # ties the element from `b` is yielded first.
    if a_result !== nothing
        a_curr, a_state = a_result
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), (iterate(self.a, a_state), b_result)
    end
    # `a` is exhausted or its head is not before `b`'s: yield from `b`.
    return T(b_curr), (a_result, iterate(self.b, b_state))
end

此代码可以正常工作,但它不是类型稳定的,因为 Julia 的迭代协议本身就不是类型稳定的(iterate 要么返回 nothing,要么返回一个元组)。在大多数情况下,编译器能够自动消除这种不稳定性,但在这里不行:下面的测试代码表明产生了临时对象(堆分配):

>>> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
>>> sum(x);
>>> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)

请注意其中的分配次数。

除了反复修改代码、寄希望于编译器能够消除类型歧义之外,还有什么方法可以有效地调试这类情况?对于这个具体例子,有没有不产生临时对象(即零分配)的解决方案?

如何诊断问题?

答案:使用 @code_warntype

运行:

julia> @code_warntype iterate(x, iterate(x)[2])
Variables
  #self#::Core.Const(iterate)
  self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering}
  state::Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
  @_4::Int64
  @_5::Int64
  @_6::Union{}
  @_7::Int64
  b_state::Int64
  b_curr::Int64
  a_state::Int64
  a_curr::Int64
  b_result::Tuple{Int64, Int64}
  a_result::Tuple{Int64, Int64}

Body::Tuple{Int64, Any}
1 ─       nothing
│         Core.NewvarNode(:(@_4))
│         Core.NewvarNode(:(@_5))
│         Core.NewvarNode(:(@_6))
│         Core.NewvarNode(:(b_state))
│         Core.NewvarNode(:(b_curr))
│         Core.NewvarNode(:(a_state))
│         Core.NewvarNode(:(a_curr))
│   %9  = Base.indexed_iterate(state, 1)::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(2)])
│         (a_result = Core.getfield(%9, 1))
│         (@_7 = Core.getfield(%9, 2))
│   %12 = Base.indexed_iterate(state, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(3)])
│         (b_result = Core.getfield(%12, 1))
│   %14 = (b_result === Main.nothing)::Core.Const(false)
└──       goto #3 if not %14
2 ─       Core.Const(:(a_result === Main.nothing))
│         Core.Const(:(%16))
│         Core.Const(:(return Main.nothing))
│         Core.Const(:(Base.indexed_iterate(a_result, 1)))
│         Core.Const(:(a_curr = Core.getfield(%19, 1)))
│         Core.Const(:(@_6 = Core.getfield(%19, 2)))
│         Core.Const(:(Base.indexed_iterate(a_result, 2, @_6)))
│         Core.Const(:(a_state = Core.getfield(%22, 1)))
│         Core.Const(:(($(Expr(:static_parameter, 1)))(a_curr)))
│         Core.Const(:(Base.getproperty(self, :a)))
│         Core.Const(:(Main.iterate(%25, a_state)))
│         Core.Const(:(Core.tuple(%26, b_result)))
│         Core.Const(:(Core.tuple(%24, %27)))
└──       Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (b_curr = Core.getfield(%30, 1))
│         (@_5 = Core.getfield(%30, 2))
│   %33 = Base.indexed_iterate(b_result, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (b_state = Core.getfield(%33, 1))
│   %35 = (a_result !== Main.nothing)::Core.Const(true)
└──       goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (a_curr = Core.getfield(%37, 1))
│         (@_4 = Core.getfield(%37, 2))
│   %40 = Base.indexed_iterate(a_result, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (a_state = Core.getfield(%40, 1))
│   %42 = Base.Order::Core.Const(Base.Order)
│   %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│   %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│   %45 = a_curr::Int64
│   %46 = (%43)(%44, %45, b_curr)::Bool
└──       goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│   %49 = Base.getproperty(self, :a)::Vector{Int64}
│   %50 = Main.iterate(%49, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %51 = Core.tuple(%50, b_result)::Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}
│   %52 = Core.tuple(%48, %51)::Tuple{Int64, Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}}
└──       return %52
6 ┄ %54 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│   %55 = a_result::Tuple{Int64, Int64}
│   %56 = Base.getproperty(self, :b)::Vector{Int64}
│   %57 = Main.iterate(%56, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %58 = Core.tuple(%55, %57)::Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}
│   %59 = Core.tuple(%54, %58)::Tuple{Int64, Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}}
└──       return %59

可以看到,返回值中状态部分的可能类型太多:例如上面 %52 与 %59 两处 return 所返回的状态元组类型并不相同。因此,Julia 在合并这些类型时放弃了精确推断,把返回类型的第二个元素(即迭代状态)推断为 Any(见 Body::Tuple{Int64, Any})。

如何解决这个问题?

答案:减少 iterate 可能返回的类型的数量。

下面是一个快速写成的实现。我不认为它最简洁,也没有充分测试,因此可能存在错误;但它足够简单,是在您代码的基础上快速改写的,用来演示如何处理这个问题。请注意,当其中一个集合为空时,我使用了专门的分支,因为这样只遍历单个集合时应该更快:

# Variant of `MergeSorted` that calls `iterate` on both inputs once, at
# construction, and caches the results in `fa`/`fb`. Their types become the
# parameters `F1`/`F2`, so the merge state can be given a concrete type, and an
# input that was empty (`iterate` returned `nothing`) gets `Nothing` as its
# parameter, which the `iterate` methods dispatch on.
# Note: `F1`/`F2` depend on whether an input is empty, i.e. on runtime values,
# so this constructor itself is not type-stable.
struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1   # cached `iterate(a)`: `nothing` or the first (element, state)
    fb::F2   # cached `iterate(b)`: `nothing` or the first (element, state)
    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        head_a = iterate(a)
        head_b = iterate(b)
        elem = promote_type(eltype(A), eltype(B))
        S1, S2 = typeof(head_a), typeof(head_b)
        return new{elem,A,B,O,S1,S2}(a, b, order, head_a, head_b)
    end
end

# Iterator traits. `MergeSorted` has six type parameters, so the eltype method
# matches via `<:`: a signature `Type{MergeSorted{T,A,B,O}}` names a UnionAll
# over `F1,F2` and, since `Type{...}` is invariant, never matches a concrete
# instance type, so `eltype(x)` would silently fall back to `Any`.
# The merged length is not tracked, so declare `SizeUnknown` instead of the
# default `HasLength` (which makes `collect` call `length`, not defined here).
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# Merge cursor for the specialized `MergeSorted`: the pending `iterate` result
# of each input, or `nothing` once that input is exhausted. Each field is a
# small `Union` of `Nothing` and a concrete type, which the compiler can store
# without boxing.
struct State{SA,SB}
    a::Union{Nothing,SA}   # pending result for input `a`
    b::Union{Nothing,SB}   # pending result for input `b`
end

# Both inputs were empty at construction: the merged stream has no elements.
Base.iterate(::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O} = nothing

# `b` was empty at construction (`F2 === Nothing`), so the merge is just `a`.
# Iterate `a` directly with its own state, but convert each element to the
# promoted type `T`, as the general merging method does, so the yielded values
# agree with `eltype`.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
    a_result = self.fa
    a_result === nothing && return nothing
    a_curr, a_state = a_result
    return T(a_curr), a_state
end

function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
    a_result = iterate(self.a, state)
    a_result === nothing && return nothing
    a_curr, a_state = a_result
    return T(a_curr), a_state
end

# `a` was empty at construction (`F1 === Nothing`), so the merge is just `b`.
# Iterate `b` directly with its own state, but convert each element to the
# promoted type `T`, as the general merging method does, so the yielded values
# agree with `eltype`.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
    b_result = self.fb
    b_result === nothing && return nothing
    b_curr, b_state = b_result
    return T(b_curr), b_state
end

function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
    b_result = iterate(self.b, state)
    b_result === nothing && return nothing
    b_curr, b_state = b_result
    return T(b_curr), b_state
end

# First step of the general merge: wrap the heads cached by the constructor in
# a concretely typed `State` and take one merge step from it.
@inline Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}) where {T,A,B,O,F1,F2} =
    iterate(self, State{F1,F2}(self.fa, self.fb))

# One merge step of the specialized version. `state` holds the pending `iterate`
# result of each input (`nothing` once that input is exhausted). Yields the head
# that comes first under `self.order`, converted to `T`, and advances only the
# input it came from. Every path returns a `State{F1,F2}`, so the return type
# stays a small concrete `Union` (see the `@code_warntype` output).
# NOTE(review): `State{F1,F2}(...)` assumes later `iterate` calls return the same
# types as the first calls cached in `fa`/`fb`; if not, the conversion fails.
# NOTE(review): `T(x)` assumes `T` has a one-argument constructor; for an
# abstract `T` (e.g. `Any`) `convert(T, x)` may be needed — confirm.
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}, 
    state::State{F1,F2}) where {T,A,B,O,F1,F2}
    a_result, b_result = state.a, state.b

    # `b` exhausted: drain `a`, or stop once both are exhausted.
    if b_result === nothing
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    # Take from `a` only when its head is strictly before `b`'s head, so on
    # ties the element from `b` is yielded first.
    if a_result !== nothing
        a_curr, a_state = a_result
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
    end
    # `a` is exhausted or its head is not before `b`'s: yield from `b`.
    return T(b_curr), State{F1,F2}(a_result, iterate(self.b, b_state))
end

现在可以得到:

julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);

julia> sum(x)
269

julia> @allocated sum(x)
0

julia> @code_warntype iterate(x, iterate(x)[2])
Variables
  #self#::Core.Const(iterate)
  self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering, Tuple{Int64, Int64}, Tuple{Int64, Int64}}
  state::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
  @_4::Int64
  @_5::Int64
  @_6::Int64
  b_state::Int64
  b_curr::Int64
  a_state::Int64
  a_curr::Int64
  b_result::Union{Nothing, Tuple{Int64, Int64}}
  a_result::Union{Nothing, Tuple{Int64, Int64}}

Body::Union{Nothing, Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}}
1 ─       nothing
│         Core.NewvarNode(:(@_4))
│         Core.NewvarNode(:(@_5))
│         Core.NewvarNode(:(@_6))
│         Core.NewvarNode(:(b_state))
│         Core.NewvarNode(:(b_curr))
│         Core.NewvarNode(:(a_state))
│         Core.NewvarNode(:(a_curr))
│   %9  = Base.getproperty(state, :a)::Union{Nothing, Tuple{Int64, Int64}}
│   %10 = Base.getproperty(state, :b)::Union{Nothing, Tuple{Int64, Int64}}
│         (a_result = %9)
│         (b_result = %10)
│   %13 = (b_result === Main.nothing)::Bool
└──       goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└──       goto #4 if not %15
3 ─       return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (a_curr = Core.getfield(%18, 1))
│         (@_6 = Core.getfield(%18, 2))
│   %21 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (a_state = Core.getfield(%21, 1))
│   %23 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│   %24 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│   %25 = Base.getproperty(self, :a)::Vector{Int64}
│   %26 = Main.iterate(%25, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %27 = (%24)(%26, b_result::Core.Const(nothing))::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│   %28 = Core.tuple(%23, %27)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└──       return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (b_curr = Core.getfield(%30, 1))
│         (@_5 = Core.getfield(%30, 2))
│   %33 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (b_state = Core.getfield(%33, 1))
│   %35 = (a_result !== Main.nothing)::Bool
└──       goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (a_curr = Core.getfield(%37, 1))
│         (@_4 = Core.getfield(%37, 2))
│   %40 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (a_state = Core.getfield(%40, 1))
│   %42 = Base.Order::Core.Const(Base.Order)
│   %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│   %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│   %45 = a_curr::Int64
│   %46 = (%43)(%44, %45, b_curr)::Bool
└──       goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│   %49 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│   %50 = Base.getproperty(self, :a)::Vector{Int64}
│   %51 = Main.iterate(%50, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %52 = (%49)(%51, b_result::Tuple{Int64, Int64})::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│   %53 = Core.tuple(%48, %52)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└──       return %53
8 ┄ %55 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│   %56 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│   %57 = a_result::Union{Nothing, Tuple{Int64, Int64}}
│   %58 = Base.getproperty(self, :b)::Vector{Int64}
│   %59 = Main.iterate(%58, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %60 = (%56)(%57, %59)::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│   %61 = Core.tuple(%55, %60)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└──       return %61

编辑:我现在意识到这个实现并不完全正确。它假设 iterate 的返回值(在不为 nothing 时)类型是固定的,但这一点并不一定成立。例如,如果后续调用 iterate(a, s) 返回的元组类型与第一次调用 iterate(a) 的不同,那么把它存入 State{F1,F2} 时就会因无法转换而报错。而如果返回类型本身不稳定,编译器就必须进行分配。因此,完全正确的方案应该先检查 iterate 是否类型稳定:如果是,就使用我的方案;否则就使用(例如)您原来的方案。