Python 3 保证数组参数维数正确的方法是什么?

What is the Python 3 way to ensure the correct dimension of array arguments?

在我的新手Python 3.7项目中,很多函数的参数都是numpy.ndarray的。这些必须是二维 r x n 矩阵。行维度 r 是必不可少的:一些函数需要 1 x n 向量,其他函数需要 2 x n 矩阵,r 最多三个甚至更多。还为任何 r x n 数组定义了函数。 (列维度 n 对于设计目的不是必需的。)

根据我的 Matlab 经验,此要求可能会令人困惑且容易出错。所以我考虑了以下方法:

  1. 记录方法参数(当然!)
  2. 单元测试(当然!)
  3. 在某些函数中进行验证并抛出异常。 (但是,这不是很实用,也不是很高效。)
  4. 定义数据 类:OneRowTwoRowsThreeRowsFourPlusRows。每个都有一个 ndarray 字段,在构造函数中验证。好处包括类型提示和更好的领域建模,如 DDD。缺点是额外的复杂性。

问题: 鉴于 Python 3 中引入的类型提示和函数式编程的趋势,当前 pythonic 方法是什么来解决这个问题?

Python 最好的地方之一是 duck typing,Numpy 通常与该设计方法非常兼容。假设您有一个纯向量函数 vecfunc。您可以在函数的开头添加一些样板,将任何一维数组膨胀为 1 x n 向量:

def vecfunc(arr):
    if arr.ndim==1:
        arr = arr[None, :]

    ...function body goes here...

这将避免由于 arr 维度太少而导致的任何问题,并且在大多数情况下可能仍会给出正确的行为。但是,它不会阻止用户传入 r x n x m 数组或 15 x n 数组。最终,您将不得不使用方法 3. 来处理这些东西,并在适当的地方抛出一些异常。例如:

def vecfunc(arr):
    if not 0 < arr.ndim < 3:
        raise ValueError("arr must have ndim of 1 or 2. arr.ndim: %d" % arr.ndim)
    elif arr.ndim==1:
        arr = arr[None, :]

如果这让您感觉好些,numpy and scipy 的代码库在许多函数中都有这种基于形状的异常检查,无论何时何地都需要它们。

当然,您始终可以在开发任何给定函数的最后阶段停止添加这些类型的异常检查。您可能会对产生合理行为的输入范围感到惊讶。

如果您对类型注释死心塌地,可以通过 writing your code using Cython 获得类似的东西。例如,如果您想要一个只接受二维整数数组的 add 函数,您可以在 .pyx 文件中编写以下函数:

import numpy as np

def add(long[:, :] arr1, long[:, :] arr2):
    assert tuple(arr1.shape) == tuple(arr2.shape)

    result = np.zeros((arr1.shape[0], arr1.shape[1]), dtype=np.long)
    cdef long[:, :] result_view = result

    for x in range(arr1.shape[0]):
        for y in range(arr1.shape[1]):
            result_view[x, y] = arr1[x, y] + arr2[x, y]

    return result

有关编写和编译 Cython 的更多详细信息,请参阅上面链接的文档。

这与其说是 "type annotations",不如说是真正的强类型化,但它可以满足您的要求。可悲的是,我找不到固定单个维度大小的方法,只能找到总维度数。