在 Numpy 和 PyTorch 之间切换 function/class 实现:?

Switch function/class implementation between Numpy & PyTorch:?

我有一个函数(实际上是一个 class,但为了简单起见,我们假设它是一个函数)使用 PyTorch 中都存在的几个 NumPy 操作,例如np.add 而且我还想要一个 PyTorch 版本的函数。我试图避免重复我的代码,所以我想知道:

有没有办法让我在 NumPy 和 PyTorch 之间来回动态切换函数的执行而不需要重复实现?

举个小例子,假设我的函数是:

def foo_numpy(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  return np.add(x, y)

我可以定义一个 PyTorch 等价物:

def foo_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  return torch.add(x, y)

我能以某种方式定义一个函数吗:

def foo(x, y, mode: str = 'numpy'):
  if mode == 'numpy':
    return np.add(x, y)
  elif mode == 'torch':
    return torch.add(x, y)
  else:
    raise ValueError

不需要 if-else 语句?

编辑:像下面这样的东西怎么样?

def foo(x, y, mode: str = 'numpy'):
  if mode == 'numpy':
    lib = np
  elif mode == 'torch':
    lib = torch
  else:
    raise ValueError
  return lib.add(x, y)

您可以使用布尔值 (bool) 来表示您要使用的模式,而不是使用字符串,即 False (0) 表示 NumPy,True (1) 表示 PyTorch。然后可以使用三元运算符进一步缩小 if 语句。

def foo(x, y, mode: bool = 0):
    lib = torch if mode else np
    return lib.add(x, y) 

如果你想在class两者之间来回切换,你可以做类似的事情

class Example:

    def __init__(self):
        self._mode = True
    
    def switchMode(self):
        self._mode = !self._mode

    def foo(self, x, y):
        lib = torch if self._mode else np
        return lib.add(x, y)