Python 为具有“@”(矩阵乘法)的对象键入提示
Python type hint for objects that have "@" (matrix-multiply)
我有一个函数 fun()
接受 NumPy ArrayLike 和一个“矩阵”,以及 returns 一个 numpy 数组。
from numpy.typing import ArrayLike
import numpy as np
def fun(A, x: ArrayLike) -> np.ndarray:
return (A @ x) ** 2 - 27.0
对于具有 @
操作的实体,正确的 type
是什么?请注意 fun()
也可以接受 scipy.sparse;也许更多。
您可以使用 typing.Protocol
断言该类型实现了 __matmul__
。
class SupportsMatrixMultiplication(typing.Protocol):
def __matmul__(self, x):
...
def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
return (A @ x) ** 2 - 27.0
我相信,如果您想要的不仅仅是支持 @
作为运算符,您可以通过为 x
和 return 类型提示提供类型提示来进一步完善它。
我有一个函数 fun()
接受 NumPy ArrayLike 和一个“矩阵”,以及 returns 一个 numpy 数组。
from numpy.typing import ArrayLike
import numpy as np
def fun(A, x: ArrayLike) -> np.ndarray:
return (A @ x) ** 2 - 27.0
对于具有 @
操作的实体,正确的 type
是什么?请注意 fun()
也可以接受 scipy.sparse;也许更多。
您可以使用 typing.Protocol
断言该类型实现了 __matmul__
。
class SupportsMatrixMultiplication(typing.Protocol):
def __matmul__(self, x):
...
def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
return (A @ x) ** 2 - 27.0
我相信,如果您想要的不仅仅是支持 @
作为运算符,您可以通过为 x
和 return 类型提示提供类型提示来进一步完善它。