不同的 return 值取决于通用常量的值

Different return values dependent on value of generic constant

让我们考虑这样一个函数:

fn test<const N: usize>() -> [f64; N] {
    if N == 1 {
        [0.0_f64; 1]
    } else if N == 2 {
        [1.0_f64; 2]
    } else {
        panic!()
    }
}

我的理解是编译器会在编译时评估N的值。如果是这种情况,if 语句也可以在编译时求值,因此正确的类型应该是 returned,因为 [0.0_f64; 1] 只有 returned if [=如果 N == 2.

,则 17=] 和 [1.0_f64; 2] 仅 returned

现在,当我尝试编译这段代码时,编译器失败了,基本上告诉我 returned 数组的维度是错误的,因为它们没有明确地将 N 作为长度。

我意识到,我可以将这个具体示例实现为

fn test<const N: usize>() -> [f64; N] {
    match N {
        1 => { [0.0_f64; N] },
        2 => { [1.0_f64; N] },
        _ => { panic!("Invalid value {}", N) },
    }
}

但这在我的实际代码中不起作用,因为它为不同的分支使用具有固定数组大小的不同函数。

有没有办法做到这一点?也许使用像 #![cfg] makro?

这样的东西

为了澄清为什么我的问题不起作用,让我们把它写出来:

fn some_fct() -> [f64; 1] {
    [0.0_f64; 1]
}
fn some_other_fct() -> [f64; 2] {
    [1.0_f64; 2]
}

fn test<const N: usize>() -> [f64; N] {
    match N {
        1 => some_fct(),
        2 => some_other_fct(),
        _ => {
            panic!("Invalid value {}", N)
        }
    }
}

而且由于程序结构中的其他限制,我无法真正将 some_fct()some_other_fct() 写入具有通用大小的 return。

您可以使用通用特征来做到这一点:

trait Test<const N: usize> {
    fn test() -> [f64; N];
}

然后为零大小的类型实现它:

struct T;

impl Test<1> for T {
    fn test() -> [f64; 1] {
        return [0.0_f64; 1];
    }
}

impl Test<2> for T {
    fn test() -> [f64; 2] {
        return [1.0_f64; 2];
    }
}

缺点是调用有点麻烦:

fn main() {
    dbg!(<T as Test<1>>::test());
    dbg!(<T as Test<2>>::test());
}

但是正如下面@eggyal 的评论,您可以添加一个带有 well-written 绑定的通用函数以获得所需的语法:

fn test<const N: usize>() -> [f64; N]
where
    T: Test<N>
{
    T::test()
}
fn main() {
    dbg!(test::<1>());
    dbg!(test::<2>());
}

现在,您没有“使用错误的 Npanic!” 的行为。考虑一个功能而不是限制:如果你使用错误 N 你的代码将无法编译而不是在运行时出现恐慌。

如果您真的想要 panic!() 行为,您可以使用 #![feature(specialization)] 的不稳定特性来获得它,只需将 default 添加到此实现中即可:

impl<const N: usize> Test<N> for T {
    default fn test() -> [f64; N] {
        panic!();
    }
}

但是该功能被明确标记为不完整,所以我还不能指望它。

这里有一个解决方案不是特别聪明,但很容易理解并且类似于原始解决方案:

fn test<const N: usize>() -> [f64; N] {
    match N {
        1 => some_fct().as_slice().try_into().unwrap(),
        2 => some_other_fct().as_slice().try_into().unwrap(),
        _ => {
            panic!("Invalid value {}", N)
        }
    }
}

尽管代码看起来像是在 运行 时检查数组大小,但 godbolt 表明 rustc/LLVM 能够推断出 [f64; N].as_slice().try_into() 总是成功强制 array-turned-slice 到 [f64; N]test<1>test<2> 的生成代码因此不包含检查或恐慌,而 N>2test<N> 只是由于 catch-all 匹配臂中的恐慌而无条件地恐慌。