如何避免自定义 Julia 迭代器中的内存分配?
How to avoid memory allocations in custom Julia iterators?
考虑以下 Julia“复合”迭代器:它将两个迭代器 a 和 b(假定二者都已按 order 排好序)合并为一个有序序列:
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy merge of two iterators `a` and `b` into a single ordered sequence.
Both inputs are assumed to already be sorted according to `order`; the
constructor does not check this.

The element type `T` is `promote_type(eltype(A), eltype(B))`.
"""
struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        T = promote_type(eltype(A), eltype(B))
        return new{T,A,B,O}(a, b, order)
    end
end

# The element type is carried as the first type parameter.
function Base.eltype(::Type{MergeSorted{T,A,B,O}}) where {T,A,B,O}
    return T
end
# Advance the merged iteration by one element.
#
# `state` is a 2-tuple holding the pending `iterate` result of `self.a` and of
# `self.b` (each either `nothing` or an `(element, inner_state)` tuple). With
# no `state` argument both sources are started from scratch.
#
# Returns `nothing` once both sources are exhausted, otherwise
# `(T(element), new_state)`.
#
# NOTE: `new_state` is a plain tuple, and the three return sites below build
# tuples of different types, so the state type is not concrete.
@inline function Base.iterate(self::MergeSorted{T},
state=(iterate(self.a), iterate(self.b))) where T
a_result, b_result = state
# `b` is exhausted: stop if `a` is exhausted too, otherwise drain `a`.
if b_result === nothing
a_result === nothing && return nothing
a_curr, a_state = a_result
return T(a_curr), (iterate(self.a, a_state), b_result)
end
b_curr, b_state = b_result
if a_result !== nothing
a_curr, a_state = a_result
# Emit from `a` only when it is strictly less under `order`; ties go to `b`.
Base.Order.lt(self.order, a_curr, b_curr) &&
return T(a_curr), (iterate(self.a, a_state), b_result)
end
# Either `a` is exhausted or its head is not less than the head of `b`.
return T(b_curr), (a_result, iterate(self.b, b_state))
end
此代码可以工作,但类型不稳定,因为 Julia 的迭代机制本身就是类型不稳定的。在大多数情况下,编译器能够自动优化掉这种不稳定性,但在这里却不行:下面的测试代码表明产生了临时对象(内存分配):
julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
julia> sum(x);
julia> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)
记下分配计数。
除了反复修改代码、寄希望于编译器能够优化掉类型歧义之外,有没有更有效的方法来调试这类情况?另外,针对这个具体例子,有人知道不产生临时对象的解决方案吗?
如何诊断问题?
答案:使用@code_warntype
运行:
julia> @code_warntype iterate(x, iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering}
state::Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
@_4::Int64
@_5::Int64
@_6::Union{}
@_7::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Tuple{Int64, Int64}
a_result::Tuple{Int64, Int64}
Body::Tuple{Int64, Any}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.indexed_iterate(state, 1)::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(2)])
│ (a_result = Core.getfield(%9, 1))
│ (@_7 = Core.getfield(%9, 2))
│ %12 = Base.indexed_iterate(state, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(3)])
│ (b_result = Core.getfield(%12, 1))
│ %14 = (b_result === Main.nothing)::Core.Const(false)
└── goto #3 if not %14
2 ─ Core.Const(:(a_result === Main.nothing))
│ Core.Const(:(%16))
│ Core.Const(:(return Main.nothing))
│ Core.Const(:(Base.indexed_iterate(a_result, 1)))
│ Core.Const(:(a_curr = Core.getfield(%19, 1)))
│ Core.Const(:(@_6 = Core.getfield(%19, 2)))
│ Core.Const(:(Base.indexed_iterate(a_result, 2, @_6)))
│ Core.Const(:(a_state = Core.getfield(%22, 1)))
│ Core.Const(:(($(Expr(:static_parameter, 1)))(a_curr)))
│ Core.Const(:(Base.getproperty(self, :a)))
│ Core.Const(:(Main.iterate(%25, a_state)))
│ Core.Const(:(Core.tuple(%26, b_result)))
│ Core.Const(:(Core.tuple(%24, %27)))
└── Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (b_curr = Core.getfield(%30, 1))
│ (@_5 = Core.getfield(%30, 2))
│ %33 = Base.indexed_iterate(b_result, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (b_state = Core.getfield(%33, 1))
│ %35 = (a_result !== Main.nothing)::Core.Const(true)
└── goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%37, 1))
│ (@_4 = Core.getfield(%37, 2))
│ %40 = Base.indexed_iterate(a_result, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%40, 1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44, %45, b_curr)::Bool
└── goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %49 = Base.getproperty(self, :a)::Vector{Int64}
│ %50 = Main.iterate(%49, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %51 = Core.tuple(%50, b_result)::Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}
│ %52 = Core.tuple(%48, %51)::Tuple{Int64, Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}}
└── return %52
6 ┄ %54 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│ %55 = a_result::Tuple{Int64, Int64}
│ %56 = Base.getproperty(self, :b)::Vector{Int64}
│ %57 = Main.iterate(%56, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %58 = Core.tuple(%55, %57)::Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}
│ %59 = Core.tuple(%54, %58)::Tuple{Int64, Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}}
└── return %59
可以看到,返回值的可能类型太多,因此 Julia 放弃了对它们进行特化(并把返回类型的第二个元素推断为 Any)。
如何解决这个问题?
答案:减少 iterate 的返回类型的可能选项个数。
下面是一个快速写出的实现(我并不认为它是最简洁的,也没有对它做过充分测试,因此可能存在一些错误;但它足够简单,是在你的代码基础上快速改写的,用来说明如何处理你的问题。另外请注意,当其中一个集合为空时我使用了专门的分支,因为这样只迭代单个集合应该更快):
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy merge of two iterators `a` and `b` into a single ordered sequence.
Both inputs are assumed to already be sorted according to `order`; the
constructor does not check this.

The constructor eagerly calls `iterate(a)` and `iterate(b)` and stores the
results in `fa` and `fb`. Their types become the type parameters `F1` and `F2`,
so a source that is empty at construction time (`Nothing`) can be selected by
dispatch. Because the first elements are captured here, changes made to the
sources after construction are not reflected in `fa`/`fb`.

The element type `T` is `promote_type(eltype(A), eltype(B))`.
"""
struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1
    fb::F2

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        fa, fb = iterate(a), iterate(b)
        T = promote_type(eltype(A), eltype(B))
        return new{T,A,B,O,typeof(fa),typeof(fb)}(a, b, order, fa, fb)
    end
end

# `Type{MergeSorted{T,A,B,O}}` never matches a concrete six-parameter
# `MergeSorted`, which made `eltype` silently fall back to `Any`. Matching on
# `<:MergeSorted{T}` covers every concrete instantiation.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T

# No `length` method is defined, so declare the size as unknown instead of
# inheriting the default `HasLength()` trait.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()
"""
    State{Ta,Tb}

Pair of pending per-source results: `a` holds either `nothing` or a value of
type `Ta`, and `b` holds either `nothing` or a value of type `Tb`.
"""
struct State{Ta,Tb}
    a::Union{Ta,Nothing}
    b::Union{Tb,Nothing}
end
# Both stored first results are `nothing` (`F1 == F2 == Nothing`), i.e. both
# sources were empty when the `MergeSorted` was built: iteration ends at once.
Base.iterate(::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O} = nothing
# Single-source fast paths.
#
# When one source was empty at construction time (`F1` or `F2` is `Nothing`)
# the merge simply forwards to the other source and reuses that source's own
# iteration state. Elements are converted to `T` so these paths yield the same
# element type as the general merge path; previously they returned the raw
# elements of the wrapped iterator, which could differ from `T`
# (e.g. `MergeSorted([1, 2], Float64[])` yielded `Int` although `T == Float64`).
# `convert` is the identity when the element already has type `T`.

# `b` was empty: the first element is the stored first result of `a`.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
    curr, next = self.fa
    return convert(T, curr), next
end

# `b` was empty: continue `a` from its own state.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
    result = iterate(self.a, state)
    result === nothing && return nothing
    curr, next = result
    return convert(T, curr), next
end

# `a` was empty: the first element is the stored first result of `b`.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
    curr, next = self.fb
    return convert(T, curr), next
end

# `a` was empty: continue `b` from its own state.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
    result = iterate(self.b, state)
    result === nothing && return nothing
    curr, next = result
    return convert(T, curr), next
end
# Start of iteration in the general case: wrap the first results stored in
# `self` into a `State{F1,F2}` and hand over to the two-argument method.
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}) where {T,A,B,O,F1,F2}
    initial = State{F1,F2}(self.fa, self.fb)
    return iterate(self, initial)
end
# Advance the merged iteration by one element.
#
# `state.a` / `state.b` hold the pending `iterate` result of each source
# (`nothing` once that source is exhausted). Returns `nothing` when both are
# exhausted, otherwise `(T(element), new_state)`.
#
# Every return site wraps the new state in the same `State{F1,F2}` type, so
# the state type does not vary between branches.
#
# NOTE(review): `State{F1,F2}(...)` is fed the result of later `iterate` calls,
# so this presumably relies on those results being `nothing` or having the same
# type as the first result (`F1` / `F2`) — confirm for the wrapped iterators.
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2},
state::State{F1,F2}) where {T,A,B,O,F1,F2}
a_result, b_result = state.a, state.b
# `b` is exhausted: stop if `a` is exhausted too, otherwise drain `a`.
if b_result === nothing
a_result === nothing && return nothing
a_curr, a_state = a_result
return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
end
b_curr, b_state = b_result
if a_result !== nothing
a_curr, a_state = a_result
# Emit from `a` only when it is strictly less under `order`; ties go to `b`.
Base.Order.lt(self.order, a_curr, b_curr) &&
return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
end
# Either `a` is exhausted or its head is not less than the head of `b`.
return T(b_curr), State{F1,F2}(a_result, iterate(self.b, b_state))
end
现在你有:
julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
julia> sum(x)
269
julia> @allocated sum(x)
0
julia> @code_warntype iterate(x, iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering, Tuple{Int64, Int64}, Tuple{Int64, Int64}}
state::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
@_4::Int64
@_5::Int64
@_6::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Union{Nothing, Tuple{Int64, Int64}}
a_result::Union{Nothing, Tuple{Int64, Int64}}
Body::Union{Nothing, Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.getproperty(state, :a)::Union{Nothing, Tuple{Int64, Int64}}
│ %10 = Base.getproperty(state, :b)::Union{Nothing, Tuple{Int64, Int64}}
│ (a_result = %9)
│ (b_result = %10)
│ %13 = (b_result === Main.nothing)::Bool
└── goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└── goto #4 if not %15
3 ─ return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%18, 1))
│ (@_6 = Core.getfield(%18, 2))
│ %21 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%21, 1))
│ %23 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %24 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %25 = Base.getproperty(self, :a)::Vector{Int64}
│ %26 = Main.iterate(%25, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %27 = (%24)(%26, b_result::Core.Const(nothing))::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %28 = Core.tuple(%23, %27)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (b_curr = Core.getfield(%30, 1))
│ (@_5 = Core.getfield(%30, 2))
│ %33 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (b_state = Core.getfield(%33, 1))
│ %35 = (a_result !== Main.nothing)::Bool
└── goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%37, 1))
│ (@_4 = Core.getfield(%37, 2))
│ %40 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%40, 1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44, %45, b_curr)::Bool
└── goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %49 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %50 = Base.getproperty(self, :a)::Vector{Int64}
│ %51 = Main.iterate(%50, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %52 = (%49)(%51, b_result::Tuple{Int64, Int64})::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %53 = Core.tuple(%48, %52)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %53
8 ┄ %55 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│ %56 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %57 = a_result::Union{Nothing, Tuple{Int64, Int64}}
│ %58 = Base.getproperty(self, :b)::Vector{Int64}
│ %59 = Main.iterate(%58, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %60 = (%56)(%57, %59)::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %61 = Core.tuple(%55, %60)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %61
编辑:现在我意识到我的实现并不完全正确,因为它假设 iterate 的返回值(在不是 nothing 时)类型是稳定的,而事实未必如此。不过,如果它的类型不稳定,编译器就必然需要分配内存。因此,一个完全正确的解决方案应当先检查 iterate 是否类型稳定:如果是,就使用我的解决方案;如果不是,就使用例如你原来的解决方案。
考虑以下 Julia“复合”迭代器:它将两个迭代器 a 和 b(假定二者都已按 order 排好序)合并为一个有序序列:
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy merge of two iterators `a` and `b` into a single ordered sequence.
Both inputs are assumed to already be sorted according to `order`; the
constructor does not check this.

The element type `T` is `promote_type(eltype(A), eltype(B))`.
"""
struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        T = promote_type(eltype(A), eltype(B))
        return new{T,A,B,O}(a, b, order)
    end
end

# The element type is carried as the first type parameter.
function Base.eltype(::Type{MergeSorted{T,A,B,O}}) where {T,A,B,O}
    return T
end
# Advance the merged iteration by one element.
#
# `state` is a 2-tuple holding the pending `iterate` result of `self.a` and of
# `self.b` (each either `nothing` or an `(element, inner_state)` tuple). With
# no `state` argument both sources are started from scratch.
#
# Returns `nothing` once both sources are exhausted, otherwise
# `(T(element), new_state)`.
#
# NOTE: `new_state` is a plain tuple, and the three return sites below build
# tuples of different types, so the state type is not concrete.
@inline function Base.iterate(self::MergeSorted{T},
state=(iterate(self.a), iterate(self.b))) where T
a_result, b_result = state
# `b` is exhausted: stop if `a` is exhausted too, otherwise drain `a`.
if b_result === nothing
a_result === nothing && return nothing
a_curr, a_state = a_result
return T(a_curr), (iterate(self.a, a_state), b_result)
end
b_curr, b_state = b_result
if a_result !== nothing
a_curr, a_state = a_result
# Emit from `a` only when it is strictly less under `order`; ties go to `b`.
Base.Order.lt(self.order, a_curr, b_curr) &&
return T(a_curr), (iterate(self.a, a_state), b_result)
end
# Either `a` is exhausted or its head is not less than the head of `b`.
return T(b_curr), (a_result, iterate(self.b, b_state))
end
此代码可以工作,但类型不稳定,因为 Julia 的迭代机制本身就是类型不稳定的。在大多数情况下,编译器能够自动优化掉这种不稳定性,但在这里却不行:下面的测试代码表明产生了临时对象(内存分配):
julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
julia> sum(x);
julia> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)
记下分配计数。
除了反复修改代码、寄希望于编译器能够优化掉类型歧义之外,有没有更有效的方法来调试这类情况?另外,针对这个具体例子,有人知道不产生临时对象的解决方案吗?
如何诊断问题?
答案:使用@code_warntype
运行:
julia> @code_warntype iterate(x, iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering}
state::Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
@_4::Int64
@_5::Int64
@_6::Union{}
@_7::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Tuple{Int64, Int64}
a_result::Tuple{Int64, Int64}
Body::Tuple{Int64, Any}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.indexed_iterate(state, 1)::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(2)])
│ (a_result = Core.getfield(%9, 1))
│ (@_7 = Core.getfield(%9, 2))
│ %12 = Base.indexed_iterate(state, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(3)])
│ (b_result = Core.getfield(%12, 1))
│ %14 = (b_result === Main.nothing)::Core.Const(false)
└── goto #3 if not %14
2 ─ Core.Const(:(a_result === Main.nothing))
│ Core.Const(:(%16))
│ Core.Const(:(return Main.nothing))
│ Core.Const(:(Base.indexed_iterate(a_result, 1)))
│ Core.Const(:(a_curr = Core.getfield(%19, 1)))
│ Core.Const(:(@_6 = Core.getfield(%19, 2)))
│ Core.Const(:(Base.indexed_iterate(a_result, 2, @_6)))
│ Core.Const(:(a_state = Core.getfield(%22, 1)))
│ Core.Const(:(($(Expr(:static_parameter, 1)))(a_curr)))
│ Core.Const(:(Base.getproperty(self, :a)))
│ Core.Const(:(Main.iterate(%25, a_state)))
│ Core.Const(:(Core.tuple(%26, b_result)))
│ Core.Const(:(Core.tuple(%24, %27)))
└── Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (b_curr = Core.getfield(%30, 1))
│ (@_5 = Core.getfield(%30, 2))
│ %33 = Base.indexed_iterate(b_result, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (b_state = Core.getfield(%33, 1))
│ %35 = (a_result !== Main.nothing)::Core.Const(true)
└── goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%37, 1))
│ (@_4 = Core.getfield(%37, 2))
│ %40 = Base.indexed_iterate(a_result, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%40, 1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44, %45, b_curr)::Bool
└── goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %49 = Base.getproperty(self, :a)::Vector{Int64}
│ %50 = Main.iterate(%49, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %51 = Core.tuple(%50, b_result)::Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}
│ %52 = Core.tuple(%48, %51)::Tuple{Int64, Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}}
└── return %52
6 ┄ %54 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│ %55 = a_result::Tuple{Int64, Int64}
│ %56 = Base.getproperty(self, :b)::Vector{Int64}
│ %57 = Main.iterate(%56, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %58 = Core.tuple(%55, %57)::Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}
│ %59 = Core.tuple(%54, %58)::Tuple{Int64, Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}}
└── return %59
可以看到,返回值的可能类型太多,因此 Julia 放弃了对它们进行特化(并把返回类型的第二个元素推断为 Any)。
如何解决这个问题?
答案:减少 iterate 的返回类型的可能选项个数。
下面是一个快速写出的实现(我并不认为它是最简洁的,也没有对它做过充分测试,因此可能存在一些错误;但它足够简单,是在你的代码基础上快速改写的,用来说明如何处理你的问题。另外请注意,当其中一个集合为空时我使用了专门的分支,因为这样只迭代单个集合应该更快):
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy merge of two iterators `a` and `b` into a single ordered sequence.
Both inputs are assumed to already be sorted according to `order`; the
constructor does not check this.

The constructor eagerly calls `iterate(a)` and `iterate(b)` and stores the
results in `fa` and `fb`. Their types become the type parameters `F1` and `F2`,
so a source that is empty at construction time (`Nothing`) can be selected by
dispatch. Because the first elements are captured here, changes made to the
sources after construction are not reflected in `fa`/`fb`.

The element type `T` is `promote_type(eltype(A), eltype(B))`.
"""
struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1
    fb::F2

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        fa, fb = iterate(a), iterate(b)
        T = promote_type(eltype(A), eltype(B))
        return new{T,A,B,O,typeof(fa),typeof(fb)}(a, b, order, fa, fb)
    end
end

# `Type{MergeSorted{T,A,B,O}}` never matches a concrete six-parameter
# `MergeSorted`, which made `eltype` silently fall back to `Any`. Matching on
# `<:MergeSorted{T}` covers every concrete instantiation.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T

# No `length` method is defined, so declare the size as unknown instead of
# inheriting the default `HasLength()` trait.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()
"""
    State{Ta,Tb}

Pair of pending per-source results: `a` holds either `nothing` or a value of
type `Ta`, and `b` holds either `nothing` or a value of type `Tb`.
"""
struct State{Ta,Tb}
    a::Union{Ta,Nothing}
    b::Union{Tb,Nothing}
end
# Both stored first results are `nothing` (`F1 == F2 == Nothing`), i.e. both
# sources were empty when the `MergeSorted` was built: iteration ends at once.
Base.iterate(::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O} = nothing
# Single-source fast paths.
#
# When one source was empty at construction time (`F1` or `F2` is `Nothing`)
# the merge simply forwards to the other source and reuses that source's own
# iteration state. Elements are converted to `T` so these paths yield the same
# element type as the general merge path; previously they returned the raw
# elements of the wrapped iterator, which could differ from `T`
# (e.g. `MergeSorted([1, 2], Float64[])` yielded `Int` although `T == Float64`).
# `convert` is the identity when the element already has type `T`.

# `b` was empty: the first element is the stored first result of `a`.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
    curr, next = self.fa
    return convert(T, curr), next
end

# `b` was empty: continue `a` from its own state.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
    result = iterate(self.a, state)
    result === nothing && return nothing
    curr, next = result
    return convert(T, curr), next
end

# `a` was empty: the first element is the stored first result of `b`.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
    curr, next = self.fb
    return convert(T, curr), next
end

# `a` was empty: continue `b` from its own state.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
    result = iterate(self.b, state)
    result === nothing && return nothing
    curr, next = result
    return convert(T, curr), next
end
# Start of iteration in the general case: wrap the first results stored in
# `self` into a `State{F1,F2}` and hand over to the two-argument method.
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}) where {T,A,B,O,F1,F2}
    initial = State{F1,F2}(self.fa, self.fb)
    return iterate(self, initial)
end
# Advance the merged iteration by one element.
#
# `state.a` / `state.b` hold the pending `iterate` result of each source
# (`nothing` once that source is exhausted). Returns `nothing` when both are
# exhausted, otherwise `(T(element), new_state)`.
#
# Every return site wraps the new state in the same `State{F1,F2}` type, so
# the state type does not vary between branches.
#
# NOTE(review): `State{F1,F2}(...)` is fed the result of later `iterate` calls,
# so this presumably relies on those results being `nothing` or having the same
# type as the first result (`F1` / `F2`) — confirm for the wrapped iterators.
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2},
state::State{F1,F2}) where {T,A,B,O,F1,F2}
a_result, b_result = state.a, state.b
# `b` is exhausted: stop if `a` is exhausted too, otherwise drain `a`.
if b_result === nothing
a_result === nothing && return nothing
a_curr, a_state = a_result
return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
end
b_curr, b_state = b_result
if a_result !== nothing
a_curr, a_state = a_result
# Emit from `a` only when it is strictly less under `order`; ties go to `b`.
Base.Order.lt(self.order, a_curr, b_curr) &&
return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
end
# Either `a` is exhausted or its head is not less than the head of `b`.
return T(b_curr), State{F1,F2}(a_result, iterate(self.b, b_state))
end
现在你有:
julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
julia> sum(x)
269
julia> @allocated sum(x)
0
julia> @code_warntype iterate(x, iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering, Tuple{Int64, Int64}, Tuple{Int64, Int64}}
state::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
@_4::Int64
@_5::Int64
@_6::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Union{Nothing, Tuple{Int64, Int64}}
a_result::Union{Nothing, Tuple{Int64, Int64}}
Body::Union{Nothing, Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.getproperty(state, :a)::Union{Nothing, Tuple{Int64, Int64}}
│ %10 = Base.getproperty(state, :b)::Union{Nothing, Tuple{Int64, Int64}}
│ (a_result = %9)
│ (b_result = %10)
│ %13 = (b_result === Main.nothing)::Bool
└── goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└── goto #4 if not %15
3 ─ return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%18, 1))
│ (@_6 = Core.getfield(%18, 2))
│ %21 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%21, 1))
│ %23 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %24 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %25 = Base.getproperty(self, :a)::Vector{Int64}
│ %26 = Main.iterate(%25, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %27 = (%24)(%26, b_result::Core.Const(nothing))::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %28 = Core.tuple(%23, %27)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (b_curr = Core.getfield(%30, 1))
│ (@_5 = Core.getfield(%30, 2))
│ %33 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (b_state = Core.getfield(%33, 1))
│ %35 = (a_result !== Main.nothing)::Bool
└── goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%37, 1))
│ (@_4 = Core.getfield(%37, 2))
│ %40 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%40, 1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44, %45, b_curr)::Bool
└── goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %49 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %50 = Base.getproperty(self, :a)::Vector{Int64}
│ %51 = Main.iterate(%50, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %52 = (%49)(%51, b_result::Tuple{Int64, Int64})::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %53 = Core.tuple(%48, %52)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %53
8 ┄ %55 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│ %56 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %57 = a_result::Union{Nothing, Tuple{Int64, Int64}}
│ %58 = Base.getproperty(self, :b)::Vector{Int64}
│ %59 = Main.iterate(%58, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %60 = (%56)(%57, %59)::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %61 = Core.tuple(%55, %60)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %61
编辑:现在我意识到我的实现并不完全正确,因为它假设 iterate 的返回值(在不是 nothing 时)类型是稳定的,而事实未必如此。不过,如果它的类型不稳定,编译器就必然需要分配内存。因此,一个完全正确的解决方案应当先检查 iterate 是否类型稳定:如果是,就使用我的解决方案;如果不是,就使用例如你原来的解决方案。