Python 为具有“@”(矩阵乘法)的对象键入提示

Python type hint for objects that have "@" (matrix-multiply)

我有一个函数 fun() 接受 NumPy ArrayLike 和一个“矩阵”,以及 returns 一个 numpy 数组。

from numpy.typing import ArrayLike
import numpy as np

def fun(A, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0

对于具有 @ 操作的实体,正确的 type 是什么?请注意 fun() 也可以接受 scipy.sparse;也许更多。

您可以使用 typing.Protocol 断言该类型实现了 __matmul__

class SupportsMatrixMultiplication(typing.Protocol):
    def __matmul__(self, x):
        ...


def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0

我相信,如果您想要的不仅仅是支持 @ 作为运算符,您可以通过为 x 和 return 类型提示提供类型提示来进一步完善它。