在 NumPy 中连接视图

Concatenating views in NumPy

通过 slices/indexes 索引 NumPy 数组创建一个轻量级视图(不复制数据)并允许分配给原始数组的元素。即

import numpy as np
a = np.array([1, 2, 3, 4, 5])
a[2:4] = [6, 7]
print(a)
# [1 2 6 7 5]

但是多个视图如何,我如何连接它们以创建一个更大的视图,该视图仍然分配给原始的第一个数组。例如。对于虚函数 concatenate_views(...):

a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
concatenate_views((a[1:3], a[4:6], a[7:9])) = [11, 12, 13, 14, 15, 16]
print(a)
# should print [1 11 12 4 13 14 7 15 16 10]

当然,我可以为它查看的每个视图创建一个索引列表,只需将切片转换为索引,然后连接这些索引即可。通过这种方式,我将获得串联视图的所有索引,并可以使用这些索引来创建组合视图。但这不是我想要的。我希望 NumPy 保留切片表示的概念,因为所有切片都可能很长,并且将这些切片转换并存储为索引会效率低下。我希望 NumPy 知道所有串联视图的底层切片,以便在内部简单地循环切片范围。

也可以很好地概括问题。不仅连接视图,还允许形成任意 slicing/indexing 操作树,例如连接视图,然后应用一些切片,然后索引,然后切片,然后再次连接。也是N维的slicing/indexing。 IE。所有花哨的东西都可以通过单个未连接的视图完成。

串联视图的要点只是效率。因为我们可以用 N 维整数索引数组(坐标,如 meshgrid)来表示任何视图或切片操作,然后可以使用这个数组来制作源数组的视图。但是如果 numpy 可以保留切片源集而不是整数数组的概念,那么首先它将是轻量级的(内存消耗少得多),其次不是从内存中读取索引,numpy 可以更有效地循环(迭代)每个切片在 C++ 循环中。

通过连接视图,我希望能够以有效的方式将 np.mean(...) 等任何 numpy 操作应用于组合视图。

下面描述了基于2D示例的连接N-D切片视图的完整过程:

    1 Step described below:
    
    2D array slicing using 3 slices for each axis
    
    a,b,c - sizes of "slices" along axis 0
    d,e,f - sizes of "slices" along axis 1
    
    Each "slice" - is either slice(start, stop, step) or 1D array of integer indexes
            
      d e f
     .......
    a.0.1.2.
     .......
    b.3.4.5.
     .......
    c.6.7.8.
     .......

    Above 0 1 2 3 4 5 6 7 8 mean not a single integer but some 2D sub-array.
    Dots (`.`) also mean some 2D sub-arrays.
    
    Sub-views shapes:
    0:(a, d), 1:(a, e), 2:(a, f)
    3:(b, d), 4:(b, e), 5:(b, f)
    6:(c, d), 7:(c, e), 8:(c, f)

    Final aggregated (concatenated) view shape:
    ((a + b + c), (d + e + f))
    containing 2D array
    012
    345
    678
    
    There can be more than one Steps, each next Step applies new sequence of slicing
    to the final view obtained on previous Step. Each Step has different set of sizes
    of slices and different amount of slices per each dimension.
    In general each next Step reduces number of total elements, except the case
    when slices or indexes overlap then you may get more elements but with duplicates.

您可以使用 np.r_ 连接切片对象并分配回索引数组:

a[np.r_[1:3, 4:6, 7:9]] = [11, 12, 13, 14, 15, 16]

print(a)
array([ 1, 11, 12,  4, 13, 14,  7, 15, 16, 10])

更新

根据你的更新,我想你可能需要这样的东西:

from itertools import islice

it = iter([11, 12, 13, 14, 15, 16])
for s in slice(1,3), slice(4,6), slice(7,9):
    a[s] = list(islice(it, s.stop-s.start))

print(a)
array([ 1, 11, 12,  4, 13, 14,  7, 15, 16, 10])

只有在数据类型、步幅和偏移方面连续的视图才能连接起来。这是一种检查方法。这种方式可能不完整,但它说明了它的要点。基本上,如果视图共享一个基础,并且步幅和偏移量对齐以便它们在同一网格上,则可以连接。

本着 TDD 的精神,我将使用以下示例:

x = np.arange(24).reshape(4, 6)

我们(或至少我)希望以下内容可以串联:

a, b = x[:, :4], x[:, 4:]        # Basic case
a, b = x[:, :4:2], x[:, 4::2]    # Strided
a, b = x[:, :4:2], x[:, 2::2]    # Strided overlapping
a, b = x[1:2, 1:4], x[2:4, 1:4]  # Stacked

# Completely reshaped:
a, b = x.ravel()[:12].reshape(3, 4), x.ravel()[12:].reshape(3, 4)
# Equivalent to
a, b = x[:2, :].reshape(3, 4), x[2:, :].reshape(3, 4)

我们希望以下内容可以串联:

a, b = x, np.arange(12).reshape(2, 6)   # Buffer mismatch
a, b = x[0, :].view(np.uint), x[1:, :]  # Dtype mismatch
a, b = x[:, ::2], x[:, ::3]             # Stride mismatch
a, b = x[:, :4], x[:, 4::2]             # Stride mismatch
a, b = x[:, :3], x[:, 4:]               # Overlap mismatch
a, b = x[:, :4:2], x[:, 3::2]           # Overlap mismatch
a, b = x[:-1, :-1], x[1:, 1:]           # Overlap mismatch
a, b = x[:-1, :4], x[:, 4:]             # Shape mismatch

以下内容可以解释为可串联,但在这种情况下不会:

a, b = x, x[1:-1, 1:-1]

想法是一切(dtype、strides、offsets)都必须完全匹配。视图之间只允许一个轴偏移不同,只要它与另一个视图的边缘的距离不超过一步。唯一可能的例外是当一个视图完全包含在另一个视图中时,但我们将在这里忽略这种情况。如果我们对偏移量和步幅使用数组操作,则推广到多个维度应该非常简单。

def cat_slices(a, b):
    if a.base is not b.base:
        raise ValueError('Buffer mismatch')
    if a.dtype != b.dtype:  # I don't thing you can use `is` here in general
        raise ValueError('Dtype mismatch')

    sa = np.array(a.strides)
    sb = np.array(b.strides)

    if (sa != sb).any():
        raise ValueError('Stride mismatch')

    oa = np.byte_bounds(a)[0]
    ob = np.byte_bounds(b)[0]

    if oa > ob:
        a, b = b, a
        oa, ob = ob, oa

    offset = ob - oa

    # Check if you can get to `b` from a by moving along exactly one axis
    # This part works consistently for arrays with internal overlap
    div = np.zeros_like(sa)
    mod = np.ones_like(sa)  # Use ones to auto-flag divide-by zero
    np.divmod(offset, sa, where=sa.astype(bool), out=(div, mod))

    zeros = np.flatnonzero((mod == 0) & (div >= 0) & (div <= a.shape))

    if not zeros.size:
        raise ValueError('Overlap mismatch')

    axis = zeros[0]

    check_shape = np.equal(a.shape, b.shape)
    check_shape[axis] = True

    if not check_shape.all():
        raise ValueError('Shape mismatch')

    shape = list(a.shape)
    shape[axis] = b.shape[axis] + div[axis]

    start = np.byte_bounds(a)[0] - np.byte_bounds(a.base)[0]

    return np.ndarray(shape, dtype=a.dtype, buffer=a.base, offset=start, strides=a.strides)

这个函数不处理的一些事情:

  • 合并标志
  • 广播
  • 处理完全包含在彼此内但具有 multi-axis 偏移量的数组
  • 负步幅

但是,您可以检查它是否 returns 上面显示的所有情况的预期视图(和错误)。在更 production-y 的版本中,我可以设想这种增强 np.concatenate,因此对于失败的情况,它只会复制数据而不是引发错误。