我如何判断一个函数是否可以在上下文管理器中使用?

how can I tell if a function can be used in a context manager?

为了改变numpy数组的打印精度x,我一直在使用这个:

with np.printoptions(precision=2, suppress=True):
    print(x)

我想对火炬张量 aTensor 做同样的事情,但下面的方法不起作用(我得到 AttributeError: __enter__ ):

with torch.set_printoptions(precision=2):
    print(aTensor)

我是 python 的新手,环顾四周并阅读了要在上下文管理器中使用的函数,它需要具有 __enter____exit__。但是当我尝试检查时,我发现 np.printoptionstorch.set_printoptions 都没有 __enter__ 作为属性:hasattr(np.printoptions, "__enter__") returns False 和相同hasattr(torch.set_printoptions, "__enter__").

但是前者可以在上下文管理器中使用,而后者不能。这是为什么?

火炬张量打印精度的直接问题我可以通过更改精度、打印然后再更改精度来解决。我对学习上下文管理器的基础知识更感兴趣。提前致谢。

np.printoptions 没有 __enter__ 属性的原因是它是 returns 上下文管理器的函数;它本身不是上下文管理器。

>>> from contextlib import AbstractContextManager
>>> cm = np.printoptions(precision=2, suppress=True)
>>> cm
<contextlib._GeneratorContextManager object at 0x636f6e747874>
>>> isinstance(cm, AbstractContextManager)
True

请注意,并非所有上下文管理器都将成为 contextlib._GeneratorContextManager 对象; numpy 恰好使用标准库中的 contextlib 来创建这个上下文管理器。

回答你的字面问题,最后一行代码将检查某物是否是上下文管理器;您可以检查它是否是 contextlib.AbstractContextManager 的实例。如果您需要检查代码中的某些内容是否是上下文管理器,那么您应该这样做。如果您只需要根据自己的知识临时检查并且不愿意那样做,那么您可以:检查它具有 __enter____exit__ 属性的 REPL,方法是尝试自动完成它们或使用 dir();尝试将其用作上下文管理器;或检查 documentation/implementation.


np.printoptions 不同,torch.set_printoptions 甚至没有 return 上下文管理器,这就是为什么你得到 AttributeError。但是,您可以创建自己的上下文管理器来为您处理 torch.set_printoptions,然后您可以像 np.printoptions 一样使用它。这是一个例子;我还没有对此进行测试,但任何潜在的问题都可以得到解决。可以看到相关代码here.

import contextlib
import copy

import torch

@contextlib.contextmanager
def torch_set_printoptions_cm(*args, **kwargs):
    try:
        # be warned, torch._tensor_str is a private module,
        # not bound by API guarantees
        original_options = torch._tensor_str.PRINT_OPTS
        torch._tensor_str.PRINT_OPTS = copy.copy(original_options)

        torch.set_printoptions(*args, **kwargs)
        yield torch._tensor_str.PRINT_OPTS
    finally:
        torch._tensor_str.PRINT_OPTS = original_options

然后您可以按照您的尝试进行操作:

with torch_set_printoptions_cm(precision=2):
    print(aTensor)