python 根据变量值使用不同的装饰器参数
python use different decorator parameters dependent on a variable value
示例代码如下:
@deviceCountAtLeast(1)
if NO_DOUBLE:
@dtypes(torch.float)
else:
@dtypes(torch.float, torch.double)
def test_requires_grad_factory(self, devices, dtype):
fns = [torch.ones_like, torch.testing.randn_like]
x = torch.randn(2, 3, dtype=dtype, device=devices[0])
for fn in fns:
for requires_grad in [True, False]:
output = fn(x, dtype=dtype, device=devices[0], requires_grad=requires_grad)
self.assertEqual(requires_grad, output.requires_grad)
self.assertIs(dtype, output.dtype)
self.assertEqual(devices[0], str(x.device))
如您所见,我想根据 NO_DOUBLE
值选择 @dtypes()
装饰器的参数列表。
我目前的解决方法就像使用另一个函数来 return 不同的装饰器:
def no_double(cond, dec1, dec2):
return dec1 if cond else dec2
@no_double(NO_DOUBLE, dtypes(torch.float), dtypes(torch.float, torch.double))
def test_requires_grad_factory(self, devices, dtype):
从 Python3.9 开始,您可以使用任何表达式作为装饰器,请参阅 PEP 614
在使用的装饰器是三元表达式的结果的情况下,像下面这样的东西应该可以工作
@(dtypes(torch.float) if NO_DOUBLE else dtypes(torch.float, torch.double))
def test_requires_grad_factory(self, devices, dtype):
...
示例代码如下:
@deviceCountAtLeast(1)
if NO_DOUBLE:
@dtypes(torch.float)
else:
@dtypes(torch.float, torch.double)
def test_requires_grad_factory(self, devices, dtype):
fns = [torch.ones_like, torch.testing.randn_like]
x = torch.randn(2, 3, dtype=dtype, device=devices[0])
for fn in fns:
for requires_grad in [True, False]:
output = fn(x, dtype=dtype, device=devices[0], requires_grad=requires_grad)
self.assertEqual(requires_grad, output.requires_grad)
self.assertIs(dtype, output.dtype)
self.assertEqual(devices[0], str(x.device))
如您所见,我想根据 NO_DOUBLE
值选择 @dtypes()
装饰器的参数列表。
我目前的解决方法就像使用另一个函数来 return 不同的装饰器:
def no_double(cond, dec1, dec2):
return dec1 if cond else dec2
@no_double(NO_DOUBLE, dtypes(torch.float), dtypes(torch.float, torch.double))
def test_requires_grad_factory(self, devices, dtype):
从 Python3.9 开始,您可以使用任何表达式作为装饰器,请参阅 PEP 614
在使用的装饰器是三元表达式的结果的情况下,像下面这样的东西应该可以工作
@(dtypes(torch.float) if NO_DOUBLE else dtypes(torch.float, torch.double))
def test_requires_grad_factory(self, devices, dtype):
...