deep copy nested iterable (or improved itertools.tee for iterable of iterables)

Preamble

I have a test where I'm working with nested iterables (by nested iterable I mean an iterable that has only iterables as elements).

As a test case, consider

from itertools import tee
from typing import (Any,
                    Iterable)


def foo(nested_iterable: Iterable[Iterable[Any]]) -> Any:
    ...


def test_foo(nested_iterable: Iterable[Iterable[Any]]) -> None:
    original, target = tee(nested_iterable)  # this doesn't copy iterators elements

    result = foo(target)

    assert is_contract_satisfied(result, original)


def is_contract_satisfied(result: Any,
                          original: Iterable[Iterable[Any]]) -> bool:
    ...

For example, foo may be a simple identity function

def foo(nested_iterable: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
    return nested_iterable

and the contract simply checks that the flattened iterables have the same elements

from itertools import (chain,
                       starmap,
                       zip_longest)
from operator import eq
...
flatten = chain.from_iterable


def is_contract_satisfied(result: Iterable[Iterable[Any]],
                          original: Iterable[Iterable[Any]]) -> bool:
    return all(starmap(eq,
                       zip_longest(flatten(result), flatten(original),
                                   # we're assuming that ``object()``
                                   # will create some unique object
                                   # not presented in any of arguments
                                   fillvalue=object())))
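A quick self-contained check of this contract (the inputs are illustrative). It compares elements in flattened order, so the grouping into inner iterables does not matter, but the total length does:

```python
from itertools import chain, starmap, zip_longest
from operator import eq
from typing import Any, Iterable

flatten = chain.from_iterable


def is_contract_satisfied(result: Iterable[Iterable[Any]],
                          original: Iterable[Iterable[Any]]) -> bool:
    # pad the shorter flattened iterable with a fresh sentinel,
    # so a length mismatch makes ``eq`` return False
    return all(starmap(eq,
                       zip_longest(flatten(result), flatten(original),
                                   fillvalue=object())))


# same elements, different grouping: only the flat order is compared
same = is_contract_satisfied([[1, 2], [3]], [[1], [2, 3]])
# one extra element is compared against the sentinel, so the check fails
longer = is_contract_satisfied([[1, 2]], [[1, 2, 3]])
```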

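But if some elements of nested_iterable are iterators, they may get exhausted, because tee makes shallow copies rather than deep ones. For example, with the foo and is_contract_satisfied given above, the following statement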

>>> test_foo([iter(range(10))])

leads to the predictable

Traceback (most recent call last):
  ...
    test_foo([iter(range(10))])
  File "...", line 19, in test_foo
    assert is_contract_satisfied(result, original)
AssertionError
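The failure happens because tee only copies the outer level: both copies yield the very same inner iterator object. A minimal illustration:

```python
from itertools import tee

# a single inner iterator, shared by both outer copies
nested_iterable = [iter(range(3))]
original, target = tee(nested_iterable)

# consuming the inner iterator through ``target``...
consumed_by_target = [list(inner) for inner in target]
# ...leaves nothing for ``original``: both yield the same inner object
left_for_original = [list(inner) for inner in original]
```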

Question

How do I deep copy an arbitrarily nested iterable?

Note

I'm aware of the copy.deepcopy function, but it doesn't work for file objects.
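On Python 3 this note is easy to confirm. The sketch below uses a file opened on os.devnull as a stand-in for any text file. copy.deepcopy falls back on the `__reduce_ex__` protocol that pickle also uses, and file objects refuse that protocol:

```python
import copy
import os

with open(os.devnull) as file:
    try:
        copy.deepcopy([1, [file]])
    except TypeError as error:
        # Python 3 file objects cannot be pickled, so deepcopy gives up
        failure = error
    else:
        failure = None
```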

To address your question (how do you deep copy nested iterables?):

You can use deepcopy from the standard library:

>>> from copy import deepcopy
>>> 
>>> ni = [1, [2,3,4]]
>>> ci = deepcopy(ni)
>>> ci[1][0] = "Modified"
>>> ci
[1, ['Modified', 3, 4]]
>>> ni
[1, [2, 3, 4]]

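Update

@Azat Ibrakov said: "you're dealing with sequences; try, e.g., to deepcopy a file object (hint: it will fail)"

No, deepcopying a file object doesn't fail; you can deepcopy a file object. Note that the following demo is Python 2 code: the `file` built-in and the `xreadlines` method exist only there. Demo: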

import copy

with open('example.txt', 'w') as f:
    f.writelines(["{}\n".format(i) for i in range(100)])

with open('example.txt', 'r') as f:
    l = [1, [f]]
    c = copy.deepcopy(l)
    print(isinstance(c[1][0], file))  # Prints True.
    print("\n".join(dir(c[1][0])))

Output:

True
__class__
__delattr__
__doc__
__enter__
__exit__
__format__
__getattribute__
...
write
writelines
xreadlines

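The problem is conceptual.

According to the Python iterator protocol, the items of a container are obtained by calling the next function; see the docs here.

You won't have all the items of an object that implements the iterator protocol (such as a file object) until you have traversed the whole iterator, i.e. called next() until a StopIteration exception is raised.

That's because you cannot know in advance what the iterator's next method (__next__ in Python 3) will return.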
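In concrete terms, items become available one next call at a time, and an iterator cannot be rewound afterwards. A minimal Python 3 illustration:

```python
# the iterator protocol by hand: call ``next`` until StopIteration
iterator = iter(['a', 'b'])
items = []
while True:
    try:
        items.append(next(iterator))
    except StopIteration:
        break

# the iterator is now exhausted: a second pass yields nothing
second_pass = list(iterator)
```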

See the following example:

import random

class RandomNumberIterator:

    def __init__(self):
        self.count = 0
        self.internal_it = range(10)  # For the deepcopy demonstration below

    def __iter__(self):
        return self

    def __next__(self):
        self.count += 1
        if self.count == 10:
            raise StopIteration
        return random.randint(0, 1000)

    next = __next__  # Python 2 name of the same method

ri = RandomNumberIterator()

for i in ri:
    print(i)  # This will print random numbers each time.
              # Can you come up with some sort of mechanism to be able
              # to copy **THE CONTENT** of the `ri` iterator?

You can still do:

from copy import deepcopy

cri = deepcopy(ri)

for i in cri.internal_it:
    print(i)   # Will print numbers 0..9
               # Deepcopy on ri successful!
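deepcopy clones an iterator's internal state, not its future content. This hedged Python 3 sketch contrasts two cases: a list iterator exposes its position and can be copied, while a generator, whose state lives in a suspended frame, cannot be deep-copied at all:

```python
import copy

# a list iterator stores its position, so deepcopy can clone it
list_iterator = iter([1, 2, 3])
next(list_iterator)
list_iterator_copy = copy.deepcopy(list_iterator)

# a generator keeps its state in a suspended frame, which cannot be copied
generator = (value for value in [1, 2, 3])
try:
    copy.deepcopy(generator)
except TypeError:
    generator_copied = False
else:
    generator_copied = True
```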

A file object is a special case here because file handles are involved. As you saw above, you can deepcopy a file object, but the copy will be in a closed state.

Alternative

You can call list on the iterable, which evaluates it completely. After that you can test its contents as many times as you like.

Back to files:

with open('example.txt', 'w') as f:
    f.writelines(["{}\n".format(i) for i in range(5)])

with open('example.txt', 'r') as f:
    print(list(f))  # Prints ['0\n', '1\n', '2\n', '3\n', '4\n']
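The same idea can be checked without touching the disk. In the sketch below, io.StringIO stands in for example.txt. The file iterator can be read only once, but the list produced from it can be traversed repeatedly:

```python
import io

# io.StringIO stands in for the example.txt file
file = io.StringIO('0\n1\n2\n')
lines = list(file)  # evaluates the iterator once

# the file iterator is now exhausted...
remaining = list(file)
# ...but the list can be traversed as many times as needed
first_pass = [line.strip() for line in lines]
second_pass = [line.strip() for line in lines]
```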

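So, to sum up

You can deepcopy nested iterables, but you cannot evaluate an iterable while copying it; that makes no sense (remember RandomNumberIterator).

If you need to test the CONTENT of iterables, you need to evaluate them.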

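Naive solution

The naive algorithm is:

  1. Make n shallow copies (with tee) of every element of the original nested iterable.
  2. Make n shallow copies of the resulting stream of element copies.
  3. From each independent copy of the stream, take the element copies with the matching index.

It can be implemented as follows:
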
from itertools import tee
from operator import itemgetter
from typing import (Any,
                    Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


def copy_nested_iterable(nested_iterable: Iterable[Iterable[Domain]],
                         *,
                         count: int = 2
                         ) -> Tuple[Iterable[Iterable[Domain]], ...]:
    def shallow_copy(iterable: Iterable[Domain]) -> Tuple[Iterable[Domain], ...]:
        return tee(iterable, count)

    copies = shallow_copy(map(shallow_copy, nested_iterable))
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))
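Here is a usage sketch of the function above with two inner iterators (the inputs are illustrative). Each copy can be consumed fully and independently of the other:

```python
from itertools import tee
from operator import itemgetter
from typing import Iterable, Tuple, TypeVar

Domain = TypeVar('Domain')


def copy_nested_iterable(nested_iterable: Iterable[Iterable[Domain]],
                         *,
                         count: int = 2
                         ) -> Tuple[Iterable[Iterable[Domain]], ...]:
    def shallow_copy(iterable: Iterable[Domain]) -> Tuple[Iterable[Domain], ...]:
        return tee(iterable, count)

    # steps 1 and 2: tee every element, then tee the stream of those tuples
    copies = shallow_copy(map(shallow_copy, nested_iterable))
    # step 3: copy number ``index`` keeps the element copies with that index
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))


first, second = copy_nested_iterable([iter(range(3)), iter('ab')])
first_contents = [list(inner) for inner in first]
second_contents = [list(inner) for inner in second]
```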

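Pros:

  • Very easy to read and explain.

Cons:

  • If we want to extend the approach to iterables with deeper nesting (like iterables of nested iterables and so on), it doesn't help.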
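To see this drawback concretely, the sketch below applies the naive function (annotations dropped) to three levels of nesting. The innermost iterator is still shared between the copies:

```python
from itertools import tee
from operator import itemgetter


def copy_nested_iterable(nested_iterable, *, count=2):
    # the naive solution from above, without type annotations
    def shallow_copy(iterable):
        return tee(iterable, count)

    copies = shallow_copy(map(shallow_copy, nested_iterable))
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))


# three levels: a list of lists holding one iterator
first, second = copy_nested_iterable([[iter(range(3))]])
first_contents = [[list(deepest) for deepest in inner] for inner in first]
# the second copy sees the same, already exhausted, innermost iterator
second_contents = [[list(deepest) for deepest in inner] for inner in second]
```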

We can do better.

Improved solution

If we look at the itertools.tee function documentation, it contains a Python recipe which, with the help of the functools.singledispatch decorator, can be rewritten as

from collections import (abc,
                         deque)
from functools import singledispatch
from itertools import repeat
from typing import (Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


@singledispatch
def copy(object_: Domain,
         *,
         count: int) -> Iterable[Domain]:
    raise TypeError('Unsupported object type: {type}.'
                    .format(type=type(object_)))

# handle general case
@copy.register(object)
# immutable strings represent a special kind of iterables
# that can be copied by simply repeating
@copy.register(bytes)
@copy.register(str)
# mappings cannot be copied as other iterables
# since they are iterable only by key
@copy.register(abc.Mapping)
def copy_object(object_: Domain,
                *,
                count: int) -> Iterable[Domain]:
    return repeat(object_, count)


@copy.register(abc.Iterable)
def copy_iterable(object_: Iterable[Domain],
                  *,
                  count: int = 2) -> Tuple[Iterable[Domain], ...]:
    iterator = iter(object_)
    # we are using `itertools.repeat` instead of `range` here
    # due to efficiency of the former
    # more info at
    # 
    queues = [deque() for _ in repeat(None, count)]

    def replica(queue: deque) -> Iterable[Domain]:
        while True:
            if not queue:
                try:
                    element = next(iterator)
                except StopIteration:
                    return
                element_copies = copy(element, count=count)
                for sub_queue, element_copy in zip(queues, element_copies):
                    sub_queue.append(element_copy)
            yield queue.popleft()

    return tuple(replica(queue) for queue in queues)
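Below is a condensed, self-contained version of the recipe above, with the long comments dropped. It is exercised on three levels of nesting, where the naive version shares the innermost iterator. Here every level is split independently:

```python
from collections import abc, deque
from functools import singledispatch
from itertools import repeat
from typing import Iterable, Tuple, TypeVar

Domain = TypeVar('Domain')


@singledispatch
def copy(object_: Domain, *, count: int) -> Iterable[Domain]:
    raise TypeError('Unsupported object type: {type}.'
                    .format(type=type(object_)))


# atoms, strings and mappings are shared by reference between copies
@copy.register(object)
@copy.register(bytes)
@copy.register(str)
@copy.register(abc.Mapping)
def copy_object(object_: Domain, *, count: int) -> Iterable[Domain]:
    return repeat(object_, count)


@copy.register(abc.Iterable)
def copy_iterable(object_: Iterable[Domain],
                  *,
                  count: int = 2) -> Tuple[Iterable[Domain], ...]:
    iterator = iter(object_)
    queues = [deque() for _ in repeat(None, count)]

    def replica(queue: deque) -> Iterable[Domain]:
        while True:
            if not queue:
                try:
                    element = next(iterator)
                except StopIteration:
                    return
                # recurse: every element is itself split into ``count`` copies
                for sub_queue, element_copy in zip(queues,
                                                   copy(element, count=count)):
                    sub_queue.append(element_copy)
            yield queue.popleft()

    return tuple(replica(queue) for queue in queues)


first, second = copy_iterable([[iter(range(3)), iter('xy')]])
first_contents = [[list(deepest) for deepest in inner] for inner in first]
second_contents = [[list(deepest) for deepest in inner] for inner in second]
```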

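Pros:

  • handles deeper nesting levels and even mixed elements, e.g. iterables and non-iterables on the same level,
  • can be extended for user-defined structures (e.g. to make independent deep copies of them).

Cons:

  • less readable (but as we know, "practicality beats purity"),
  • adds some dispatch-related overhead (which is acceptable, since it is based on dictionary lookups with O(1) complexity).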

Tests

Preparation

Let's define our nested iterable as follows:

nested_iterable = [range(10 ** index) for index in range(1, 7)]

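Since creating the iterators is unrelated to the performance of the underlying copying, let's define a function for exhausting iterators (described here):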

exhaust_iterable = deque(maxlen=0).extend
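A short check of what this helper does. A zero-length deque discards everything it is extended with, so the argument is consumed to the end. It only drives the iterable it is given, though: passing a tuple of iterators walks the tuple, not the iterators inside it.

```python
from collections import deque

# a zero-length deque discards everything it is extended with
exhaust_iterable = deque(maxlen=0).extend

iterator = iter(range(5))
exhaust_iterable(iterator)
after_exhaustion = next(iterator, None)

# exhausting a tuple of iterators leaves the inner iterators untouched
inner = iter(range(3))
exhaust_iterable((inner,))
inner_left = list(inner)
```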

Timing

Using timeit:

import timeit

def naive(): exhaust_iterable(copy_nested_iterable(nested_iterable))

def improved(): exhaust_iterable(copy_iterable(nested_iterable))

print('naive approach:', min(timeit.repeat(naive)))
print('improved approach:', min(timeit.repeat(improved)))

On my laptop with Windows 10 x64 and Python 3.5.4:

naive approach: 5.1863865
improved approach: 3.5602296000000013

Memory

Using the memory_profiler package:

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.6 MiB     51.4 MiB       result = list(flatten(flatten(copy_nested_iterable(nested_iterable))))

for the "naive" approach and

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.7 MiB     51.4 MiB       result = list(flatten(flatten(copy_iterable(nested_iterable))))

for the "improved" one.

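Note: I ran each approach in a separate script. Running them in one process would not be representative, because the second statement would reuse the int objects created earlier under the hood.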


Conclusion

As we can see, both functions have similar performance, but the latter supports deeper nesting levels and looks quite extensible.

Advertisement

I've added the "improved" solution to the lz package, starting with version 0.4.0. It can be used like
>>> from lz.replication import replicate
>>> iterable = iter(range(5))
>>> list(map(list, replicate(iterable,
                             count=3)))
[[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

It is covered by property-based tests written with the hypothesis framework, so we can be sure that it works as expected.