如何为数据 类 实现 pytest.approx()

How to Implement pytest.approx() for Data Classes

假设我有一个 Python 数据 Class 我想用 pytest 测试:

@dataclass
class ExamplePoint:
    x: float
    y: float

通过简单的数学运算就可以了:

p1 = ExamplePoint(1,2)
p2 = ExamplePoint(0.5 + 0.5, 2)
p1 == p2   # True

但是浮点运算很快就会出问题:

p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
p1 == p3  #  False

使用 pytest,您可以使用 approx() 函数绕过此问题:

2.0 == approx(math.sqrt(2) * math.sqrt(2))  # True

你不能简单地将其扩展到数据 Class:

p1 = approx(p3) # Results error: TypeError: cannot make approximate comparisons to non-numeric values: ExamplePoint(x=1, y=2.0000000000000004) 

我目前解决这个问题的方法是在数据 Class 上编写一个 approx() 函数,如下所示:

from dataclasses import astuple, dataclass
import pytest

@dataclass
class ExamplePoint:
    x: float
    y: float
    
    def approx(self, other):
        return astuple(self) == pytest.approx(astuple(other))
    
p1 = ExamplePoint(1,2)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
p1.approx(p3)  # True

我不喜欢这个解决方案,因为 ExamplePoint 现在依赖于 pytest。好像不对。

我如何扩展 pytest,以便 approx() 使用我的数据 Class 而无需数据 Class 了解 pytest?

您可以使用 math.isclose(),它还允许您为您认为接近的内容设置容差。在下面的代码中,我将它分别应用于数据类之外的 ExamplePoint 的 x 和 y 坐标,但您可以采用不同的方式实现它:

@dataclass
class ExamplePoint:
    x: float
    y: float

p1 = ExamplePoint(1, 2)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))

print(math.isclose(p1.x, p3.x, rel_tol=0.01)) #True
print(math.isclose(p1.y, p3.y, rel_tol=0.01)) #True

更新:以下是如何将其合并到您的 approx 函数中:

from dataclasses import astuple, dataclass
import math
@dataclass
class ExamplePoint:
    x: float
    y: float

    def approx(self, other):
        return math.isclose(self.x,other.x, rel_tol=0.001) \
               and math.isclose(self.y,other.y, rel_tol=0.001)


p1 = ExamplePoint(1, 2)
p2 = ExamplePoint(1,1.99)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))

print(p1.approx(p2)) #False
print(p1.approx(p3)) #True

深入研究 pytest 代码(参见 https://github.com/pytest-dev/pytest/blob/main/src/_pytest/python_api.py ),似乎 pytest 检查预期值是否为 Iterable 和 Sizeable。我可以按如下方式使我的数据 Class 可迭代且可调整大小。

from dataclasses import astuple, dataclass
import pytest
import math

@dataclass
class ExamplePoint:
    x: float
    y: float

    def approx(self, other):
        return astuple(self) == pytest.approx(astuple(other))

    def __iter__(self):
        return iter(astuple(self))

    def __len__(self):
        return len(astuple(self))

p1 = ExamplePoint(1,2)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))

assert p1 == pytest.approx(p3)  # True

您应该将近似值添加到数据类的每个参数中:

from dataclasses import dataclass
import pytest

@dataclass
class Foo:
    a: int
    b: float

a = Foo(1, 3.00000001)
b = Foo(1, pytest.approx(3.0, abs=1e-3))
print(a == b)

参见:https://github.com/pytest-dev/pytest/issues/6632#issuecomment-580487103