torch.flatten() 如何从展平的维度中对元素进行排序

How does torch.flatten() order the elements from the flattened dimensions

我有一个形状为 [32,64,64,3] 的 4D 张量,它对应于 [batch, timeframes, frequency_bins, features],我做 tensor.flatten(start_dim=2)(在 PyTorch 中)。我知道形状随后将转换为 [32,64,64*3] --> [batch,timeframes,frequency_bins*features] - 但就 64*3 的新扁平化维度中元素的实际排序而言,前 64 个索引与 [:,:,:,0] 相关] 第二个 64 [:,:,:,1] 和最后一个 64 [:,:,:,2]?

是的,没错,数据的顺序保持不变(底层内存布局也是如此)。这是一个最小的例子:

>>> x = torch.rand(1, 1, 12, 3)

>>> x
tensor([[[[0.9942, 0.1458, 0.7069],
          [0.2749, 0.9886, 0.3127],
          [0.8236, 0.9903, 0.0779],
          [0.0385, 0.2587, 0.7001],
          [0.1113, 0.8011, 0.0176],
          [0.1597, 0.1456, 0.4040],
          [0.2168, 0.6588, 0.7472],
          [0.1607, 0.6133, 0.0700],
          [0.6749, 0.2120, 0.4192],
          [0.8809, 0.8893, 0.4950],
          [0.1695, 0.3175, 0.6653],
          [0.7216, 0.5403, 0.8244]]]])

>>> x.flatten()
tensor([0.9942, 0.1458, 0.7069, 0.2749, 0.9886, 0.3127, 0.8236, 0.9903, 0.0779,
        0.0385, 0.2587, 0.7001, 0.1113, 0.8011, 0.0176, 0.1597, 0.1456, 0.4040,
        0.2168, 0.6588, 0.7472, 0.1607, 0.6133, 0.0700, 0.6749, 0.2120, 0.4192,
        0.8809, 0.8893, 0.4950, 0.1695, 0.3175, 0.6653, 0.7216, 0.5403, 0.8244]))

为了便于理解,让我们首先以最简单的情况为例,其中我们有一个秩为 2 的张量,即正则矩阵。 PyTorch 在所谓的 row-major order 中执行展平,从“最内”轴遍历到“最外”轴。

取一个简单的 2 阶 3x3 数组,我们称它为 A[3, 3]:

[[a, b, c],
 [d, e, f],
 [g, h, i]]

将其从最内轴到最外轴展平将得到 [a, b, c, d, e, f, g, h, i]。让我们称这个扁平数组为 B[3].

A(在索引[i, j]处)和B(在索引k处)中对应元素之间的关系可以很容易地导出为:

k = A.size[1] * i + j

这是因为要到达 [i, j] 处的元素,我们首先向下移动 i 行,每行计算 A.size[1](即数组的宽度)个元素。一旦我们到达行 i,我们需要到达列 j,因此我们添加 j 以获得展平数组中的索引。

例如,元素 eA 中的索引 [1, 1] 处。在 B 中,它将按预期占用索引 3 * 1 + 1 = 4

让我们将同样的想法扩展到阶数为 4 的张量,就像您的情况一样,我们只展平最后两个轴。

同样,采用形状为 (2, 2, 2, 2) 的简单 4 阶张量 A,如下所示:

A =
[[[[ 1,  2],
   [ 3,  4]],

   [[ 5,  6],
   [ 7,  8]]],


   [[[ 9, 10],
   [11, 12]],

   [[13, 14],
   [15, 16]]]]

让我们找到 Atorch.flatten(A, start_dim=2) 的索引之间的关系(我们称其为扁平化版本 B)。

B =
[[[ 0,  1,  2,  3],
  [ 4,  5,  6,  7]],

  [[ 8,  9, 10, 11],
  [12, 13, 14, 15]]]

元素 12 在 A 中的索引 [1, 1, 0, 0]B 中的索引 [1, 1, 0] 处。请注意,轴 0 和 1 处的索引,即 [1, 1] 即使在部分展平后仍保持不变。这是因为这些轴没有展平,因此没有受到影响。

太棒了!因此,我们可以将 AB 的转换表示为

B[i, j, _] = A[i, j, _, _]

我们的任务现在减少到找到 B 的最后一个轴和 A 的最后两个轴之间的关系。但是 A[i, j, _, _] 是一个 2x2 数组,我们已经推导出关系 k = A.size[1] * i + j,

A.size[1] 现在将更改为 A.size[3],因为 3 现在是最后一个轴。但一般关系仍然存在。

填空得到AB中对应元素的关系为:

B[i, j, k] = A[i, j, m, n]

其中 k = A.size[3] * m + n.

我们可以验证这是正确的。元素 14 在 A 中的 [1, 1, 1, 0] 处。并移至 B.

中的 [1, 1, 2 * 1 + 0] = [1, 1, 2]

编辑: 添加示例

以@Molem7b5 的形状为 (1, 4, 4, 3) 的数组 A 为例,来自评论:

A 的内轴 (dim=3) 迭代到外轴 (dim=2) 得到 B 的连续元素。我的意思是:

// Using relation: A[:, :, i, j] == B[:, :, 3 * i + j]

// i = 0, all j
A[:, :, 0, 0] == B[:, :, 0]
A[:, :, 0, 1] == B[:, :, 1]
A[:, :, 0, 2] == B[:, :, 2]
// (Note the consecutive order in B.)

// i = 1, all j
A[:, :, 1, 0] == B[:, :, 3]
A[:, :, 1, 1] == B[:, :, 4]

// and so on until

A[:, :, 3, 2] == B[:, :, 11]

这应该可以让您更好地了解扁平化是如何发生的。如有疑问,请从关系中推断。