内存问题:批量构造下三角矩阵(以矢量化方式)

Memory issue: Constructing lower triangular matrix in batches (in a vectorised way)

以下是我要实现的目标:

给定一个如下所示的数组:

[1 0 0 0 0 0 0 0 0]
[1 1 0 0 0 0 0 0 0]
[1 1 1 0 0 0 0 0 0]
[1 1 1 1 0 0 0 0 0]
[1 1 1 1 1 0 0 0 0]
[1 1 1 1 1 1 0 0 0]
[1 1 1 1 1 1 1 0 0]
[1 1 1 1 1 1 1 1 0]
[1 1 1 1 1 1 1 1 1]

请记住,这是一个 [9 x 9] 的小数组,可以一次性初始化。我要构建的数组的形状是[90000 x 90000],因此不能一次性初始化。

将这个 [9 x 9] 数组转换成批后,它看起来像:

    [[1 0 0 0 0 0 0 0 0]
    [1 1 0 0 0 0 0 0 0]
    [1 1 1 0 0 0 0 0 0]],

    [[1 1 1 1 0 0 0 0 0]
    [1 1 1 1 1 0 0 0 0]
    [1 1 1 1 1 1 0 0 0]],

    [[1 1 1 1 1 1 1 0 0]
    [1 1 1 1 1 1 1 1 0]
    [1 1 1 1 1 1 1 1 1]]

我可以使用这些单独的切片来执行操作。

如何批量初始化一个 [90000 x 90000] 的数组,并保持数组中数组的位置?

我试过 tf.linalg.LinearOperatorLowerTrinangular() 算子。这是输出的样子:

>>> arr = tf.linalg.LinearOperatorLowerTriangular(tf.ones((3,3,9)))
>>> arr.to_dense()
<tf.Tensor: id=28, shape=(3, 3, 9), dtype=float32, numpy=
array([[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.]],

       [[1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.]],

       [[1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.]]], dtype=float32)>

在输出中,可以看到每批中的都是从第一个索引开始的。

此外,当我初始化“arr”时:

>>> arr = tf.linalg.LinearOperatorLowerTriangular(tf.ones((3,3,9)))

我用过 tf.ones(),在我的例子中,我什至不能用它来初始化 90000 x 90000 数组。

请提出construct/initialise这样一个矩阵的有效方法?

编辑:

我到底在找什么?关于如何制作 90000 x 90000 大矩阵的 chunks/sub 矩阵的逻辑。例如,我可以创建 10 个大小为 9000 x 90000 的子矩阵。但是,我如何初始化这 10 个子矩阵以查看就像一个单一的下三角矩阵。 我将使用这些子矩阵来执行向量矩阵乘法。一旦我找到一种方法为这些子矩阵赋值(以一种方式,当附加在一起时看起来像一个单一的下三角矩阵)然后在执行向量矩阵乘法之后,我可以将所有 10 个结果矩阵附加到一个。

我不完全确定这是否是您需要的,但这里有一种在 TensorFlow 中通过循环中的块计算类似内容的方法(此处 2.x,但应该类似于 1.x):

import tensorflow as tf

@tf.function
def mult_tril_chunks(vector, num_chunks=None, chunk_size=None):
    # Function accepts either number of chunks or size of each chunk
    # Last chunk may have different size
    if num_chunks is None == chunk_size is None:
        raise ValueError('one and only one of num_chunks and chunk_size must be given.')
    vector = tf.convert_to_tensor(vector)
    # Compute number of chunks or chunk size depending on given parameters
    s = tf.shape(vector)[0]
    if chunk_size is not None:
        chunk_size = tf.dtypes.cast(chunk_size, s.dtype)
        num_chunks = (s // chunk_size) + tf.dtypes.cast(s % chunk_size != 0, s.dtype)
    else:
        n = tf.dtypes.cast(num_chunks, s.dtype)
        chunk_size = s // n + tf.dtypes.cast(s % n != 0, s.dtype)
    # If you know all chunks will always have the same size
    # you can instead pass element_shape=[chunk_size]
    ta = tf.TensorArray(vector.dtype, size=num_chunks, element_shape=[None],
                        infer_shape=False)
    # Do computation in a loop
    _, ta = tf.while_loop(
        lambda i, ta: i < num_chunks,
        lambda i, ta: _mult_tril_chunks_loop(i, ta, vector, s, chunk_size),
        [0, ta]
    )
    # Return concatenated result
    return ta.concat()

def _mult_tril_chunks_loop(i, ta, vector, s, chunk_size):
    # Chunk bounds
    start = i * chunk_size
    end = tf.math.minimum(start + chunk_size, s)
    # Make slice of lower triangular matrix
    r = tf.range(s)
    tril_slice = r <= tf.expand_dims(tf.range(start, end), axis=1)
    tril_slice = tf.dtypes.cast(tril_slice, vector.dtype)
    # Compute product
    mult_slice = tf.linalg.matvec(tril_slice, vector)
    return i + 1, ta.write(i, mult_slice)

# Test
vector = tf.range(10)
# Result with normal computation
tril = tf.linalg.band_part(tf.ones((10, 10), vector.dtype), -1, 0)
res = tf.linalg.matvec(tril, vector)
print(res.numpy())
# [ 0  1  3  6 10 15 21 28 36 45]
# Giving number of chunks
res2 = mult_tril_chunks(vector, num_chunks=3)
print(res2.numpy())
# [ 0  1  3  6 10 15 21 28 36 45]
# Giving chunk size
res3 = mult_tril_chunks(vector, chunk_size=4)
print(res3.numpy())
# [ 0  1  3  6 10 15 21 28 36 45]

由于您关心内存使用情况,因此可能值得检查 tf.while_loop 的一些更高级的参数,例如 运行.[=13= 的并行迭代次数]