Dataclasses:将通用 TypeVar 名称与源中的属性匹配 class
Dataclasses: Matching Generic TypeVar names to attributes in the origin class
假设我有一个如下所示的通用数据类:
from dataclasses import dataclass
from typing import TypeVar, Generic
T = TypeVar('T')
U = TypeVar('U')
@dataclass
class Class(Generic[T, U]):
foo: U
bar: T
IntStrClass = Class[int, str]
当我们阅读代码时,您可以看到 IntStrClass
:
T
与 int
对齐,这使得 bar
的类型成为 int
。
U
与 str
对齐,这使得 foo
的类型成为 str
。
但是我怎样才能以编程方式解决这个问题呢?
我一直在研究 typing
模块,但无法从输出中看出如何匹配它们。我拥有的是:
from typing import get_type_hints, get_origin, get_args
print("Class field types:", get_type_hints(get_origin(IntStrClass)))
print("Class generic args:", get_args(IntStrClass))
Class field types: {'foo': ~U, 'bar': ~T}
Class generic args: (<class 'int'>, <class 'str'>)
我在这里缺少的是根据 Class
的定义来确定 T -> int
和 U -> str
。如果我有这些信息,那么我可以推断出 foo
和 bar
.
的正确类型
提前致谢!
这个怎么样?
[在评论中的对话后进行了大量编辑。]
from dataclasses import dataclass
from typing import TypeVar, Generic, get_type_hints, get_args, get_origin
T = TypeVar('T')
U = TypeVar('U')
@dataclass
class Class(Generic[T, U]):
foo: U
spam: str
bar: T
baz: int
IntStrClass = Class[int, str]
def get_annotations(generic_subclass):
generic_origin = get_origin(generic_subclass)
annotations_map = get_type_hints(generic_origin)
generic_args = get_args(generic_subclass)
try:
generic_params = generic_origin.__parameters__
except AttributeError as err:
raise AttributeError(
f"{origin} has no attribute '__parameters__'. "
"The likely cause of this is that the typing module's "
"API for the Generic class has changed "
"since this function was written."
) from err
type_var_map = dict(zip(generic_params, generic_args))
for field, annotation in annotations_map.items():
if isinstance(annotation, TypeVar):
annotations_map[field] = type_var_map[annotation]
return annotations_map
print("Resolved attributes:", get_annotations(IntStrClass))
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
我想我可能找到了解决方案,但我不确定它的安全性。通用子类似乎公开了一个 __parameters__
字段,我认为我可以在这里利用它:
def get_hints(clazz):
origin = get_origin(clazz)
hints = get_type_hints(origin)
clazz_args = get_args(clazz)
if hasattr(origin, "__parameters__"):
typevars = origin.__parameters__
for typevar, resolved_typevar in zip(typevars, clazz_args):
for attr_name in hints:
if hints[attr_name] == typevar:
hints[attr_name] = resolved_typevar
return hints
使用 Alex Waygood 的示例进行测试:
T = TypeVar('T')
U = TypeVar('U')
@dataclass
class Class(Generic[T, U]):
foo: U
spam: str
bar: T
baz: int
IntStrClass = Class[int, str]
print("Resolved attributes:", get_hints(IntStrClass))
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
我确定我遗漏了一些特殊情况。
假设我有一个如下所示的通用数据类:
from dataclasses import dataclass
from typing import TypeVar, Generic
T = TypeVar('T')
U = TypeVar('U')
@dataclass
class Class(Generic[T, U]):
foo: U
bar: T
IntStrClass = Class[int, str]
当我们阅读代码时,您可以看到 IntStrClass
:
T
与int
对齐,这使得bar
的类型成为int
。U
与str
对齐,这使得foo
的类型成为str
。
但是我怎样才能以编程方式解决这个问题呢?
我一直在研究 typing
模块,但无法从输出中看出如何匹配它们。我拥有的是:
from typing import get_type_hints, get_origin, get_args
print("Class field types:", get_type_hints(get_origin(IntStrClass)))
print("Class generic args:", get_args(IntStrClass))
Class field types: {'foo': ~U, 'bar': ~T}
Class generic args: (<class 'int'>, <class 'str'>)
我在这里缺少的是根据 Class
的定义来确定 T -> int
和 U -> str
。如果我有这些信息,那么我可以推断出 foo
和 bar
.
提前致谢!
这个怎么样?
[在评论中的对话后进行了大量编辑。]
from dataclasses import dataclass
from typing import TypeVar, Generic, get_type_hints, get_args, get_origin
T = TypeVar('T')
U = TypeVar('U')
@dataclass
class Class(Generic[T, U]):
foo: U
spam: str
bar: T
baz: int
IntStrClass = Class[int, str]
def get_annotations(generic_subclass):
generic_origin = get_origin(generic_subclass)
annotations_map = get_type_hints(generic_origin)
generic_args = get_args(generic_subclass)
try:
generic_params = generic_origin.__parameters__
except AttributeError as err:
raise AttributeError(
f"{origin} has no attribute '__parameters__'. "
"The likely cause of this is that the typing module's "
"API for the Generic class has changed "
"since this function was written."
) from err
type_var_map = dict(zip(generic_params, generic_args))
for field, annotation in annotations_map.items():
if isinstance(annotation, TypeVar):
annotations_map[field] = type_var_map[annotation]
return annotations_map
print("Resolved attributes:", get_annotations(IntStrClass))
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
我想我可能找到了解决方案,但我不确定它的安全性。通用子类似乎公开了一个 __parameters__
字段,我认为我可以在这里利用它:
def get_hints(clazz):
origin = get_origin(clazz)
hints = get_type_hints(origin)
clazz_args = get_args(clazz)
if hasattr(origin, "__parameters__"):
typevars = origin.__parameters__
for typevar, resolved_typevar in zip(typevars, clazz_args):
for attr_name in hints:
if hints[attr_name] == typevar:
hints[attr_name] = resolved_typevar
return hints
使用 Alex Waygood 的示例进行测试:
T = TypeVar('T')
U = TypeVar('U')
@dataclass
class Class(Generic[T, U]):
foo: U
spam: str
bar: T
baz: int
IntStrClass = Class[int, str]
print("Resolved attributes:", get_hints(IntStrClass))
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
我确定我遗漏了一些特殊情况。