Dataclasses:将通用 TypeVar 名称与源中的属性匹配 class

Dataclasses: Matching Generic TypeVar names to attributes in the origin class

假设我有一个如下所示的通用数据类:

from dataclasses import dataclass
from typing import TypeVar, Generic

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: T


IntStrClass = Class[int, str]

当我们阅读代码时,您可以看到 IntStrClass:

但是我怎样才能以编程方式解决这个问题呢?

我一直在研究 typing 模块,但无法从输出中看出如何匹配它们。我拥有的是:

from typing import get_type_hints, get_origin, get_args

print("Class field types:", get_type_hints(get_origin(IntStrClass)))
print("Class generic args:", get_args(IntStrClass))
Class field types: {'foo': ~U, 'bar': ~T}
Class generic args: (<class 'int'>, <class 'str'>)

我在这里缺少的是根据 Class 的定义来确定 T -> intU -> str。如果我有这些信息,那么我可以推断出 foobar.

的正确类型

提前致谢!

这个怎么样?

[在评论中的对话后进行了大量编辑。]

from dataclasses import dataclass
from typing import TypeVar, Generic, get_type_hints, get_args, get_origin

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    spam: str
    bar: T
    baz: int


IntStrClass = Class[int, str]

def get_annotations(generic_subclass):
    generic_origin = get_origin(generic_subclass)
    annotations_map = get_type_hints(generic_origin)
    generic_args = get_args(generic_subclass)

    try:
        generic_params = generic_origin.__parameters__
    except AttributeError as err:
        raise AttributeError(
            f"{origin} has no attribute '__parameters__'. "
            "The likely cause of this is that the typing module's "
            "API for the Generic class has changed "
            "since this function was written."
            ) from err

    type_var_map = dict(zip(generic_params, generic_args))
    
    for field, annotation in annotations_map.items():
        if isinstance(annotation, TypeVar):
            annotations_map[field] = type_var_map[annotation]
            
    return annotations_map

print("Resolved attributes:", get_annotations(IntStrClass))
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}

我想我可能找到了解决方案,但我不确定它的安全性。通用子类似乎公开了一个 __parameters__ 字段,我认为我可以在这里利用它:

def get_hints(clazz):
    origin = get_origin(clazz)
    hints = get_type_hints(origin)
    clazz_args = get_args(clazz)
    if hasattr(origin, "__parameters__"):
        typevars = origin.__parameters__
        for typevar, resolved_typevar in zip(typevars, clazz_args):
            for attr_name in hints:
                if hints[attr_name] == typevar:
                    hints[attr_name] = resolved_typevar
    return hints

使用 Alex Waygood 的示例进行测试:

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    spam: str
    bar: T
    baz: int


IntStrClass = Class[int, str]
print("Resolved attributes:", get_hints(IntStrClass))
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}

我确定我遗漏了一些特殊情况。