如何限制数据类属性中的值?

How to limit values in dataclass attribute?

我有以下数据类 Gear,我想将 gear_level 的最大值从 0 限制到 5。但是正如你所看到的,当我增加 gear_level 时,它会高于5,这不是我想要的。我尝试了方法以及 postinit。我该如何解决这个问题?

from dataclasses import dataclass

@dataclass
class Gear:
    gear_level: int = 0
    direction: str = None
    # more codes ...

    def __postinit__(self):
        if self.gear_level <= 0:
            self.gear_level = 0
        elif 5 > self.gear_level > 0:
            self.gear_level = self.gear_level
        else:
            self.gear_level = 5

    def set_gear_level(self, level):
        if level <= 0:
            self.gear_level = 0
        elif 5 > level > 0:
            self.gear_level = level
        else:
            self.gear_level = 5


g = Gear()

g.set_gear_level(6)

print(g)

g.gear_level += 1

print(g)

g.set_gear_level(-1)

print(g)

g.gear_level -= 1
print(g)

理想情况下,我更喜欢使用 g.gear_level += 1 表示法,因为我想递增 gear_level。它不应该从档位 1 跳到 5。此外,当它递减时,它应该停在 0。它应该既接受赋值 0 又允许递减到 0。这可以做到吗?

Gear(gear_level=5, direction=None)
Gear(gear_level=6, direction=None)
Gear(gear_level=0, direction=None)
Gear(gear_level=-1, direction=None)

评论中建议的 link 提供了解决此问题的 优雅 解决方案,例如使用自定义 descriptor class,您只需进行最少的更改即可工作。

例如,下面是我如何定义 BoundsValidator 描述符 class 来检查 class 属性是否在预期的下限和上限内(请注意,这两个范围都是在这种情况下可选):

from typing import Optional

try:
    from typing import get_args
except ImportError:  # Python 3.7
    from typing_extensions import get_args


class BoundsValidator:
    """Descriptor to validate an attribute x remains within a specified bounds.

    That is, checks the constraint `low <= x <= high` is satisfied. Note that
    both low and high are optional. If none are provided, no bounds will be
    applied.
    """
    __slots__ = ('name',
                 'type',
                 'validator')

    def __init__(self, min_val: Optional[int] = None,
                 max_val: Optional[int] = float('inf')):

        if max_val is None:     # only minimum
            def validator(name, val):
                if val < min_val:
                    raise ValueError(f"values for {name!r}  have to be > {min_val}; got {val!r}")

        elif min_val is None:   # only maximum
            def validator(name, val):
                if val > max_val:
                    raise ValueError(f"values for {name!r}  have to be < {max_val}; got {val!r}")

        else:                   # both upper and lower bounds are given
            def validator(name, val):
                if not min_val <= val <= max_val:
                    raise ValueError(f"values for {name!r}  have to be within the range "
                                     f"[{min_val}, {max_val}]; got {val!r}")

        self.validator = validator

    def __set_name__(self, owner, name):
        # save the attribute name on an initial run
        self.name = name

        # set the valid types based on the annotation for the attribute
        #   for example, `int` or `Union[int, float]`
        tp = owner.__annotations__[name]
        self.type = get_args(tp) or tp

    def __get__(self, instance, owner):
        if not instance:
            return self
        return instance.__dict__[self.name]

    def __delete__(self, instance):
        del instance.__dict__[self.name]

    def __set__(self, instance, value):
        # can be removed if you don't need the type validation
        if not isinstance(value, self.type):
            raise TypeError(f"{self.name!r} values must be of type {self.type!r}")

        # validate that the value is within expected bounds
        self.validator(self.name, value)

        # finally, set the value on the instance
        instance.__dict__[self.name] = value

最后,这是我想出的示例代码,用于测试它是否按我们预期的方式工作:

from dataclasses import dataclass
from typing import Union


@dataclass
class Person:
    age: int = BoundsValidator(1)   # let's assume a person must at least be 1 years
    num: Union[int, float] = BoundsValidator(-1, 1)
    gear_level: int = BoundsValidator(0, 5)


def main():
    p = Person(10, 0.7, 5)
    print(p)

    # should raise a ValueError now
    try:
        p.gear_level += 1
    except ValueError as e:
        print(e)

    # and likewise here, for the lower bound
    try:
        p.gear_level -= 7
    except ValueError as e:
        print(e)

    # all these should now raise an error

    try:
        _ = Person(0, 0, 2)
    except ValueError as e:
        print(e)

    try:
        _ = Person(120, -3.1, 2)
    except ValueError as e:
        print(e)


if __name__ == '__main__':
    main()

当我们 运行 代码时,这会提供以下输出:

Person(age=10, num=0.7, gear_level=5)
values for 'gear_level'  have to be within the range [0, 5]; got 6
values for 'gear_level'  have to be within the range [0, 5]; got -2
values for 'age'  have to be within the range [1, inf]; got 0
values for 'num'  have to be within the range [-1, 1]; got -3.1

在这种情况下,我会简单地使用 属性:

@dataclass
class Gear:
    gear_level: int

    # Rest of the class excluded for simplicity

    @property
    def gear_level(self) -> int:
        return self._gear_level

    @gear_level.setter
    def gear_level(self, value: int) -> None:
        self._gear_level = min(max(value, 0), 5)

这样你就不需要写 __post_init__ 或者必须记住调用特定的方法:对 gear_level 的赋值将被保留 0 <= gear_level <= 5,即使 [=14] =].