class 的类型重载另一个 class 属性

Typings of class overloading another class attributes

我正在尝试创建一个 class,它使用 __getattr__ 调用另一个 class 属性,以便包装 class 调用。

from aiohttp import ClientSession
from contextlib import asynccontextmanager


class SessionThrottler:

    def __init__(self, session: ClientSession,
                 time_period: int, max_tasks: int):
        self._obj = session
        self._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
                                        time_period=time_period)

    def __getattr__(self, name):
        @asynccontextmanager
        async def _do(*args, **kwargs):
            async with self._bucket:
                res = await getattr(self._obj, name)(*args, **kwargs)
                yield res
        return _do

    async def close(self):
        await self._obj.close()

那么我可以这样做:

async def fetch(session: ClientSession):
    async with session.get('http://localhost:5051') as resp:
        _ = resp


session = ClientSession()
session_throttled = SessionThrottler(session, 4, 2)
await asyncio.gather(
    *[fetch(session_trottled) 
      for _ in range(10)]
)

这段代码工作正常,但我怎样才能将 session_throttled 推断为 ClientSession 而不是 SessionThrottler(有点像 functools.wraps)?

我要看你需要什么"is inferred as"。

创建 ClientSessions 的 ThrotledSessions 实例

使用 classes 的自然方式是通过继承 - 如果您的 SessionThrotler 继承自 ClientSession,那么自然会 be 一个 ClientSession 还有。 "small downside" 是 __getattr__ 不会按预期工作,因为仅针对实例中未找到的属性调用 - 而 Python 将 "see" 来自 ClientSession 在您的 ThrotledSession 对象中并改为调用它们。

当然,这也需要您静态继承您的 class,并且您可能希望它动态工作。 (静态地,我的意思是 必须写 class SessionThrotler(ClientSession): - 或者至少,如果有有限数量的不同 Session classes 你想包装,为每个写一个 subclass 继承自 ThrotledClass 还有:

class ThrotledClientSession(ThrotledSession, ClientSession):
    ...

如果这对您有用,那么就是通过创建 __getattribute__ 而不是 __getattr__ 来修复属性访问的问题。两者的区别在于__getattribte__包含了所有的属性查找步骤,并且在查找开始时调用。而当所有其他方法都失败时,__getattr__ 被称为正常查找的一部分(在 __getattribute__ 的标准算法中)。

class SessionThrottlerMixin:

    def __init__(self, session: ClientSession,
                 time_period: int, max_tasks: int):
        self._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
                                        time_period=time_period)

    def __getattribute__(self, name):
        attr = super().__getattribute__(name)
        if not name.startswith("_") or not callable(attr):
             return attr
        @asynccontextmanager
        async def _do(*args, **kwargs):
            async with self._bucket:
                res = await attr(*args, **kwargs)
                yield res
        return _do

    class ThrotledClientSession(SessionThrottlerMixin, ClientSession):
        pass

如果您从其他代码获取 CLientSession 实例,并且不想或不能用此代码替换基础 class,您可以在所需的位置执行此操作实例,通过分配给 __class__ 属性: 如果 ClientSession 是正常的 Python class,它不会从特殊基础继承,如 Python 内置函数,不使用 __slots__ 和其他一些限制 -该实例是 "converted" 到 ThrotledClientSession 飞行途中(但您必须执行继承操作):session.__class__ = ThrottledClientSession.

Class 以这种方式赋值不会 运行 新的 class __init__。由于您需要创建 _bucket,您可以使用 class 方法来创建存储桶并进行替换 - 因此,在带有 __getattribute__ 的版本中添加如下内容:


class SessionThrottler:
    ...
    @classmethod
    def _wrap(cls, instance, time_period: int, max_tasks: int):
       cls.__class__ = cls
       instance._bucket = AsyncLeakyBucket(max_tasks=max_tasks,
                                            time_period=time_period)
       return instance 

    ...

throtled_session = ThrotledClientSession._wrap(session, 4, 2)

如果你有很多父 classes 想要用这种方式包装,并且你不想声明它的 Throttled 版本,这个 可以 动态生成-但如果这是唯一的方法,我只会那样做。最好声明大约 10 个存根 Thotled 版本,每个版本 3 行。

虚拟子classing

如果您可以更改 ClientSession classes(以及您想要包装的其他代码)的代码,这是最不引人注目的方式 -

Python 有一个晦涩的 OOP 特性,称为 Virtual Subclassing,其中 class 可以注册为另一个 class 的子 class,而无需真正的继承。但是,要成为 "parent" 的 class 必须具有 abc.ABCMeta 作为其元 class - 否则这真的很不引人注目。

这是它的工作原理:


In [13]: from abc import ABC                                                                                                         

In [14]: class A(ABC): 
    ...:     pass 
    ...:                                                                                                                             

In [15]: class B:  # don't inherit
    ...:     pass

In [16]: A.register(B)                                                                                                               
Out[16]: __main__.B

In [17]: isinstance(B(), A)                                                                                                          
Out[17]: True

因此,在您的原始代码中,如果您可以使 ClientSession 继承自 abc.ABC(完全没有任何其他更改 )- 然后做:

ClientSession.register(SessionThrottler) 它会起作用(如果 "inferred as" 你的意思与对象类型有关)。

请注意,如果 ClientSession 和其他人有不同的 metaclass,添加 abc.ABC 作为其基础之一将失败并出现 metaclass 冲突。如果您可以更改他们的代码,这仍然是更好的方法:只需创建一个继承自两个元classes 的协作元class,然后您都设置好了:


class Session(metaclass=IDontCare):
    ...


from abc import ABCMeta

class ColaborativeMeta(ABCMeta, Session.__class__):
    pass

class ClientSession(Session, metaclass=ColaborativeMeta):
    ...

类型提示

如果你不需要 "isinstance" 来工作,只需要输入系统相同,那么只需要使用 typing.cast:

import typing as T
...
session = ClientSession()
session_throttled = T.cast(ClientSession, SessionThrottler(session, 4, 2))

该对象在 运行 时未被触及 - 只是同一个对象,但从那时起,mypy 等工具会将其视为 ClientSession 的实例。

最后但同样重要的是,更改 class 名称。

所以,如果 "inferred as" 你并不是说包装的 class 应该被视为一个实例,而是只关心 class 名称在日志中正确显示和这样,您可以将 class __name__ 属性设置为您想要的任何字符串:

class SessionThrottler:
    ...

SessionThrottelr.__name__ = ClientSession.__name__

或者在包装器上使用适当的 __repr__ 方法 class:

class SessionThrottler:
    ...
    def __repr__(self):
        return repr(self._obj)

此解决方案基于修补提供给上下文管理器的对象方法(而不是包装它们)。

import asyncio
import functools
import contextlib

class Counter:
    async def infinite(self):
        cnt = 0
        while True:
            yield cnt
            cnt += 1
            await asyncio.sleep(1)

def limited_infinite(f, limit):
    @functools.wraps(f)
    async def inner(*a, **kw):
        cnt = 0
        async for res in f(*a, **kw):
            yield res
            if cnt == limit:
                break
            cnt += 1
    return inner

@contextlib.contextmanager
def throttler(limit, counter):
    orig = counter.infinite
    counter.infinite = limited_infinite(counter.infinite, limit)
    yield counter
    counter.infinite = orig

async def main():
    with throttler(5, Counter()) as counter:
        async for x in counter.infinite():
            print('res: ', x)

if __name__ == "__main__":
    asyncio.run(main())

对于您的情况,这意味着修补 ClientSession 的每个相关方法(可能仅限于 http 方法)。不确定它是否更好。