class 的类型重载另一个 class 属性
Typings of class overloading another class attributes
我正在尝试创建一个 class,它使用 __getattr__
调用另一个 class 属性,以便包装 class 调用。
from aiohttp import ClientSession
from contextlib import asynccontextmanager
class SessionThrottler:
def __init__(self, session: ClientSession,
time_period: int, max_tasks: int):
self._obj = session
self._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
time_period=time_period)
def __getattr__(self, name):
@asynccontextmanager
async def _do(*args, **kwargs):
async with self._bucket:
res = await getattr(self._obj, name)(*args, **kwargs)
yield res
return _do
async def close(self):
await self._obj.close()
那么我可以这样做:
async def fetch(session: ClientSession):
async with session.get('http://localhost:5051') as resp:
_ = resp
session = ClientSession()
session_throttled = SessionThrottler(session, 4, 2)
await asyncio.gather(
*[fetch(session_trottled)
for _ in range(10)]
)
这段代码工作正常,但我怎样才能将 session_throttled
推断为 ClientSession
而不是 SessionThrottler
(有点像 functools.wraps
)?
我要看你需要什么"is inferred as"。
创建 ClientSessions 的 ThrotledSessions 实例
使用 classes 的自然方式是通过继承 - 如果您的 SessionThrotler
继承自 ClientSession
,那么自然会 be 一个 ClientSession
还有。
"small downside" 是 __getattr__
不会按预期工作,因为仅针对实例中未找到的属性调用 - 而 Python 将 "see" 来自 ClientSession
在您的 ThrotledSession 对象中并改为调用它们。
当然,这也需要您静态继承您的 class,并且您可能希望它动态工作。 (静态地,我的意思是
必须写 class SessionThrotler(ClientSession):
- 或者至少,如果有有限数量的不同 Session classes 你想包装,为每个写一个 subclass 继承自 ThrotledClass 还有:
class ThrotledClientSession(ThrotledSession, ClientSession):
...
如果这对您有用,那么就是通过创建 __getattribute__
而不是 __getattr__
来修复属性访问的问题。两者的区别在于__getattribte__
包含了所有的属性查找步骤,并且在查找开始时调用。而当所有其他方法都失败时,__getattr__
被称为正常查找的一部分(在 __getattribute__
的标准算法中)。
class SessionThrottlerMixin:
def __init__(self, session: ClientSession,
time_period: int, max_tasks: int):
self._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
time_period=time_period)
def __getattribute__(self, name):
attr = super().__getattribute__(name)
if not name.startswith("_") or not callable(attr):
return attr
@asynccontextmanager
async def _do(*args, **kwargs):
async with self._bucket:
res = await attr(*args, **kwargs)
yield res
return _do
class ThrotledClientSession(SessionThrottlerMixin, ClientSession):
pass
如果您从其他代码获取 CLientSession
实例,并且不想或不能用此代码替换基础 class,您可以在所需的位置执行此操作实例,通过分配给 __class__
属性:
如果 ClientSession
是正常的 Python class,它不会从特殊基础继承,如 Python 内置函数,不使用 __slots__
和其他一些限制 -该实例是 "converted" 到 ThrotledClientSession
飞行途中(但您必须执行继承操作):session.__class__ = ThrottledClientSession
.
Class 以这种方式赋值不会 运行 新的 class __init__
。由于您需要创建 _bucket
,您可以使用 class 方法来创建存储桶并进行替换 - 因此,在带有 __getattribute__
的版本中添加如下内容:
class SessionThrottler:
...
@classmethod
def _wrap(cls, instance, time_period: int, max_tasks: int):
cls.__class__ = cls
instance._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
time_period=time_period)
return instance
...
throtled_session = ThrotledClientSession._wrap(session, 4, 2)
如果你有很多父 classes 想要用这种方式包装,并且你不想声明它的 Throttled
版本,这个 可以 动态生成-但如果这是唯一的方法,我只会那样做。最好声明大约 10 个存根 Thotled 版本,每个版本 3 行。
虚拟子classing
如果您可以更改 ClientSession
classes(以及您想要包装的其他代码)的代码,这是最不引人注目的方式 -
Python 有一个晦涩的 OOP 特性,称为 Virtual Subclassing
,其中 class 可以注册为另一个 class 的子 class,而无需真正的继承。但是,要成为 "parent" 的 class 必须具有 abc.ABCMeta
作为其元 class - 否则这真的很不引人注目。
这是它的工作原理:
In [13]: from abc import ABC
In [14]: class A(ABC):
...: pass
...:
In [15]: class B: # don't inherit
...: pass
In [16]: A.register(B)
Out[16]: __main__.B
In [17]: isinstance(B(), A)
Out[17]: True
因此,在您的原始代码中,如果您可以使 ClientSession
继承自 abc.ABC
(完全没有任何其他更改 )- 然后做:
ClientSession.register(SessionThrottler)
它会起作用(如果 "inferred as" 你的意思与对象类型有关)。
请注意,如果 ClientSession 和其他人有不同的 metaclass,添加 abc.ABC
作为其基础之一将失败并出现 metaclass 冲突。如果您可以更改他们的代码,这仍然是更好的方法:只需创建一个继承自两个元classes 的协作元class,然后您都设置好了:
class Session(metaclass=IDontCare):
...
from abc import ABCMeta
class ColaborativeMeta(ABCMeta, Session.__class__):
pass
class ClientSession(Session, metaclass=ColaborativeMeta):
...
类型提示
如果你不需要 "isinstance" 来工作,只需要输入系统相同,那么只需要使用 typing.cast
:
import typing as T
...
session = ClientSession()
session_throttled = T.cast(ClientSession, SessionThrottler(session, 4, 2))
该对象在 运行 时未被触及 - 只是同一个对象,但从那时起,mypy
等工具会将其视为 ClientSession
的实例。
最后但同样重要的是,更改 class 名称。
所以,如果 "inferred as" 你并不是说包装的 class 应该被视为一个实例,而是只关心 class 名称在日志中正确显示和这样,您可以将 class __name__
属性设置为您想要的任何字符串:
class SessionThrottler:
...
SessionThrottelr.__name__ = ClientSession.__name__
或者在包装器上使用适当的 __repr__
方法 class:
class SessionThrottler:
...
def __repr__(self):
return repr(self._obj)
此解决方案基于修补提供给上下文管理器的对象方法(而不是包装它们)。
import asyncio
import functools
import contextlib
class Counter:
async def infinite(self):
cnt = 0
while True:
yield cnt
cnt += 1
await asyncio.sleep(1)
def limited_infinite(f, limit):
@functools.wraps(f)
async def inner(*a, **kw):
cnt = 0
async for res in f(*a, **kw):
yield res
if cnt == limit:
break
cnt += 1
return inner
@contextlib.contextmanager
def throttler(limit, counter):
orig = counter.infinite
counter.infinite = limited_infinite(counter.infinite, limit)
yield counter
counter.infinite = orig
async def main():
with throttler(5, Counter()) as counter:
async for x in counter.infinite():
print('res: ', x)
if __name__ == "__main__":
asyncio.run(main())
对于您的情况,这意味着修补 ClientSession
的每个相关方法(可能仅限于 http 方法)。不确定它是否更好。
我正在尝试创建一个 class,它使用 __getattr__
调用另一个 class 属性,以便包装 class 调用。
from aiohttp import ClientSession
from contextlib import asynccontextmanager
class SessionThrottler:
def __init__(self, session: ClientSession,
time_period: int, max_tasks: int):
self._obj = session
self._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
time_period=time_period)
def __getattr__(self, name):
@asynccontextmanager
async def _do(*args, **kwargs):
async with self._bucket:
res = await getattr(self._obj, name)(*args, **kwargs)
yield res
return _do
async def close(self):
await self._obj.close()
那么我可以这样做:
async def fetch(session: ClientSession):
async with session.get('http://localhost:5051') as resp:
_ = resp
session = ClientSession()
session_throttled = SessionThrottler(session, 4, 2)
await asyncio.gather(
*[fetch(session_trottled)
for _ in range(10)]
)
这段代码工作正常,但我怎样才能将 session_throttled
推断为 ClientSession
而不是 SessionThrottler
(有点像 functools.wraps
)?
我要看你需要什么"is inferred as"。
创建 ClientSessions 的 ThrotledSessions 实例
使用 classes 的自然方式是通过继承 - 如果您的 SessionThrotler
继承自 ClientSession
,那么自然会 be 一个 ClientSession
还有。
"small downside" 是 __getattr__
不会按预期工作,因为仅针对实例中未找到的属性调用 - 而 Python 将 "see" 来自 ClientSession
在您的 ThrotledSession 对象中并改为调用它们。
当然,这也需要您静态继承您的 class,并且您可能希望它动态工作。 (静态地,我的意思是
必须写 class SessionThrotler(ClientSession):
- 或者至少,如果有有限数量的不同 Session classes 你想包装,为每个写一个 subclass 继承自 ThrotledClass 还有:
class ThrotledClientSession(ThrotledSession, ClientSession):
...
如果这对您有用,那么就是通过创建 __getattribute__
而不是 __getattr__
来修复属性访问的问题。两者的区别在于__getattribte__
包含了所有的属性查找步骤,并且在查找开始时调用。而当所有其他方法都失败时,__getattr__
被称为正常查找的一部分(在 __getattribute__
的标准算法中)。
class SessionThrottlerMixin:
def __init__(self, session: ClientSession,
time_period: int, max_tasks: int):
self._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
time_period=time_period)
def __getattribute__(self, name):
attr = super().__getattribute__(name)
if not name.startswith("_") or not callable(attr):
return attr
@asynccontextmanager
async def _do(*args, **kwargs):
async with self._bucket:
res = await attr(*args, **kwargs)
yield res
return _do
class ThrotledClientSession(SessionThrottlerMixin, ClientSession):
pass
如果您从其他代码获取 CLientSession
实例,并且不想或不能用此代码替换基础 class,您可以在所需的位置执行此操作实例,通过分配给 __class__
属性:
如果 ClientSession
是正常的 Python class,它不会从特殊基础继承,如 Python 内置函数,不使用 __slots__
和其他一些限制 -该实例是 "converted" 到 ThrotledClientSession
飞行途中(但您必须执行继承操作):session.__class__ = ThrottledClientSession
.
Class 以这种方式赋值不会 运行 新的 class __init__
。由于您需要创建 _bucket
,您可以使用 class 方法来创建存储桶并进行替换 - 因此,在带有 __getattribute__
的版本中添加如下内容:
class SessionThrottler:
...
@classmethod
def _wrap(cls, instance, time_period: int, max_tasks: int):
cls.__class__ = cls
instance._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
time_period=time_period)
return instance
...
throtled_session = ThrotledClientSession._wrap(session, 4, 2)
如果你有很多父 classes 想要用这种方式包装,并且你不想声明它的 Throttled
版本,这个 可以 动态生成-但如果这是唯一的方法,我只会那样做。最好声明大约 10 个存根 Thotled 版本,每个版本 3 行。
虚拟子classing
如果您可以更改 ClientSession
classes(以及您想要包装的其他代码)的代码,这是最不引人注目的方式 -
Python 有一个晦涩的 OOP 特性,称为 Virtual Subclassing
,其中 class 可以注册为另一个 class 的子 class,而无需真正的继承。但是,要成为 "parent" 的 class 必须具有 abc.ABCMeta
作为其元 class - 否则这真的很不引人注目。
这是它的工作原理:
In [13]: from abc import ABC
In [14]: class A(ABC):
...: pass
...:
In [15]: class B: # don't inherit
...: pass
In [16]: A.register(B)
Out[16]: __main__.B
In [17]: isinstance(B(), A)
Out[17]: True
因此,在您的原始代码中,如果您可以使 ClientSession
继承自 abc.ABC
(完全没有任何其他更改 )- 然后做:
ClientSession.register(SessionThrottler)
它会起作用(如果 "inferred as" 你的意思与对象类型有关)。
请注意,如果 ClientSession 和其他人有不同的 metaclass,添加 abc.ABC
作为其基础之一将失败并出现 metaclass 冲突。如果您可以更改他们的代码,这仍然是更好的方法:只需创建一个继承自两个元classes 的协作元class,然后您都设置好了:
class Session(metaclass=IDontCare):
...
from abc import ABCMeta
class ColaborativeMeta(ABCMeta, Session.__class__):
pass
class ClientSession(Session, metaclass=ColaborativeMeta):
...
类型提示
如果你不需要 "isinstance" 来工作,只需要输入系统相同,那么只需要使用 typing.cast
:
import typing as T
...
session = ClientSession()
session_throttled = T.cast(ClientSession, SessionThrottler(session, 4, 2))
该对象在 运行 时未被触及 - 只是同一个对象,但从那时起,mypy
等工具会将其视为 ClientSession
的实例。
最后但同样重要的是,更改 class 名称。
所以,如果 "inferred as" 你并不是说包装的 class 应该被视为一个实例,而是只关心 class 名称在日志中正确显示和这样,您可以将 class __name__
属性设置为您想要的任何字符串:
class SessionThrottler:
...
SessionThrottelr.__name__ = ClientSession.__name__
或者在包装器上使用适当的 __repr__
方法 class:
class SessionThrottler:
...
def __repr__(self):
return repr(self._obj)
此解决方案基于修补提供给上下文管理器的对象方法(而不是包装它们)。
import asyncio
import functools
import contextlib
class Counter:
async def infinite(self):
cnt = 0
while True:
yield cnt
cnt += 1
await asyncio.sleep(1)
def limited_infinite(f, limit):
@functools.wraps(f)
async def inner(*a, **kw):
cnt = 0
async for res in f(*a, **kw):
yield res
if cnt == limit:
break
cnt += 1
return inner
@contextlib.contextmanager
def throttler(limit, counter):
orig = counter.infinite
counter.infinite = limited_infinite(counter.infinite, limit)
yield counter
counter.infinite = orig
async def main():
with throttler(5, Counter()) as counter:
async for x in counter.infinite():
print('res: ', x)
if __name__ == "__main__":
asyncio.run(main())
对于您的情况,这意味着修补 ClientSession
的每个相关方法(可能仅限于 http 方法)。不确定它是否更好。