涉及大小写区分的函数的类型稳定性

Type stability for a function involving case distinctions

我正在编写一个函数来计算重心插值公式的权重。忽略类型稳定性,这很简单:

function baryweights(x)
    n = length(x)
    if n == 1; return [1.0]; end # This is obviously not type stable

    xmin,xmax = extrema(x)
    x *= 4/(xmax-xmin)
    # ^ Multiply by capacity of interval to avoid overflow
    return [
        1/prod(x[i]-x[j] for j in 1:n if j != i)
        for i = 1:n
    ]
end

类型稳定性的问题是计算出 n > 1 案例的 return 类型,这样我就可以 return 一个 n == 1 中正确类型的数组案子。有没有简单的技巧可以实现这一点?

只需在一个伪参数上递归调用函数:

function baryweights(x)
    n = length(x)
    if n == 1
        T = eltype(baryweights(zeros(eltype(x),2)))
        return [one(T)]
    end

    xmin,xmax = extrema(x)
    let x = 4/(xmax-xmin) * x
        # ^ Multiply by capacity of interval to avoid overflow,
        #   and wrap in let to avoid another source of type instability
        #   (https://github.com/JuliaLang/julia/issues/15276)
        return [
            1/prod(x[i]-x[j] for j in 1:n if j != i)
            for i = 1:n
        ]
    end
end

我不确定我是否理解你的计划。但也许这样的事情可以帮助? ->

baryone(t::T) where T<:Real = [1.]
baryone(t::T) where T<:Complex = [1im]  # or whatever you like here

function baryweights(x::Array{T,1}) where T<:Number
    n = length(x)
    n == 1 && return baryone(x[1])
    xmin,xmax = extrema(x)  # don't forget fix extrema for complex! :)
    x *= 4/(xmax-xmin)
    # ^ Multiply by capacity of interval to avoid overflow
    return [
        1/prod(x[i]-x[j] for j in 1:n if j != i)
        for i = 1:n
    ]
end

警告:我还是新手!如果我尝试 @code_warntype baryweights([1]) 我只会看到很多警告。 (另外,如果我避免调用 baryone)。例如 nAny !!

编辑: 我 asked on discourse 现在看到 @code_warn return 如果我们使用另一个变量 (y) 会得到更好的结果:

function baryweights(x::Array{T,1}) where T<:Number
    n = length(x)
    n == 1 && return baryone(x[1])
    xmin,xmax = extrema(x)  # don't forget fix extrema for complex! :)
    let y = x * 4/(xmax-xmin)
        # ^ Multiply by capacity of interval to avoid overflow
        return [
            1/prod(y[i]-y[j] for j in 1:n if j != i)
            for i = 1:n
        ]
    end
end

Edit2:我添加了 let 以避免 yCore.Boxed