torch.flatten() 如何从展平的维度中对元素进行排序
How does torch.flatten() order the elements from the flattened dimensions
我有一个形状为 [32,64,64,3]
的 4D 张量,它对应于 [batch, timeframes, frequency_bins, features]
,我做 tensor.flatten(start_dim=2)
(在 PyTorch 中)。我知道形状随后将转换为 [32,64,64*3] --> [batch,timeframes,frequency_bins*features]
- 但就 64*3
的新扁平化维度中元素的实际排序而言,前 64 个索引与 [:,:,:,0]
相关] 第二个 64 [:,:,:,1]
和最后一个 64 [:,:,:,2]
?
是的,没错,数据的顺序保持不变(底层内存布局也是如此)。这是一个最小的例子:
>>> x = torch.rand(1, 1, 12, 3)
>>> x
tensor([[[[0.9942, 0.1458, 0.7069],
[0.2749, 0.9886, 0.3127],
[0.8236, 0.9903, 0.0779],
[0.0385, 0.2587, 0.7001],
[0.1113, 0.8011, 0.0176],
[0.1597, 0.1456, 0.4040],
[0.2168, 0.6588, 0.7472],
[0.1607, 0.6133, 0.0700],
[0.6749, 0.2120, 0.4192],
[0.8809, 0.8893, 0.4950],
[0.1695, 0.3175, 0.6653],
[0.7216, 0.5403, 0.8244]]]])
>>> x.flatten()
tensor([0.9942, 0.1458, 0.7069, 0.2749, 0.9886, 0.3127, 0.8236, 0.9903, 0.0779,
0.0385, 0.2587, 0.7001, 0.1113, 0.8011, 0.0176, 0.1597, 0.1456, 0.4040,
0.2168, 0.6588, 0.7472, 0.1607, 0.6133, 0.0700, 0.6749, 0.2120, 0.4192,
0.8809, 0.8893, 0.4950, 0.1695, 0.3175, 0.6653, 0.7216, 0.5403, 0.8244]))
为了便于理解,让我们首先以最简单的情况为例,其中我们有一个秩为 2 的张量,即正则矩阵。 PyTorch 在所谓的 row-major order 中执行展平,从“最内”轴遍历到“最外”轴。
取一个简单的 2 阶 3x3 数组,我们称它为 A[3, 3]
:
[[a, b, c],
[d, e, f],
[g, h, i]]
将其从最内轴到最外轴展平将得到 [a, b, c, d, e, f, g, h, i]
。让我们称这个扁平数组为 B[3]
.
A
(在索引[i, j]
处)和B
(在索引k
处)中对应元素之间的关系可以很容易地导出为:
k = A.size[1] * i + j
这是因为要到达 [i, j]
处的元素,我们首先向下移动 i
行,每行计算 A.size[1]
(即数组的宽度)个元素。一旦我们到达行 i
,我们需要到达列 j
,因此我们添加 j
以获得展平数组中的索引。
例如,元素 e
在 A
中的索引 [1, 1]
处。在 B
中,它将按预期占用索引 3 * 1 + 1 = 4
。
让我们将同样的想法扩展到阶数为 4 的张量,就像您的情况一样,我们只展平最后两个轴。
同样,采用形状为 (2, 2, 2, 2)
的简单 4 阶张量 A
,如下所示:
A =
[[[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]]],
[[[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16]]]]
让我们找到 A
和 torch.flatten(A, start_dim=2)
的索引之间的关系(我们称其为扁平化版本 B
)。
B =
[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]]]
元素 12 在 A
中的索引 [1, 1, 0, 0]
和 B
中的索引 [1, 1, 0]
处。请注意,轴 0 和 1 处的索引,即 [1, 1]
即使在部分展平后仍保持不变。这是因为这些轴没有展平,因此没有受到影响。
太棒了!因此,我们可以将 A
到 B
的转换表示为
B[i, j, _] = A[i, j, _, _]
我们的任务现在减少到找到 B
的最后一个轴和 A
的最后两个轴之间的关系。但是 A[i, j, _, _]
是一个 2x2 数组,我们已经推导出关系 k = A.size[1] * i + j
,
A.size[1]
现在将更改为 A.size[3]
,因为 3 现在是最后一个轴。但一般关系仍然存在。
填空得到A
和B
中对应元素的关系为:
B[i, j, k] = A[i, j, m, n]
其中 k = A.size[3] * m + n
.
我们可以验证这是正确的。元素 14 在 A
中的 [1, 1, 1, 0]
处。并移至 B
.
中的 [1, 1, 2 * 1 + 0] = [1, 1, 2]
编辑: 添加示例
以@Molem7b5 的形状为 (1, 4, 4, 3)
的数组 A
为例,来自评论:
从 A
的内轴 (dim=3
) 迭代到外轴 (dim=2
) 得到 B
的连续元素。我的意思是:
// Using relation: A[:, :, i, j] == B[:, :, 3 * i + j]
// i = 0, all j
A[:, :, 0, 0] == B[:, :, 0]
A[:, :, 0, 1] == B[:, :, 1]
A[:, :, 0, 2] == B[:, :, 2]
// (Note the consecutive order in B.)
// i = 1, all j
A[:, :, 1, 0] == B[:, :, 3]
A[:, :, 1, 1] == B[:, :, 4]
// and so on until
A[:, :, 3, 2] == B[:, :, 11]
这应该可以让您更好地了解扁平化是如何发生的。如有疑问,请从关系中推断。
我有一个形状为 [32,64,64,3]
的 4D 张量,它对应于 [batch, timeframes, frequency_bins, features]
,我做 tensor.flatten(start_dim=2)
(在 PyTorch 中)。我知道形状随后将转换为 [32,64,64*3] --> [batch,timeframes,frequency_bins*features]
- 但就 64*3
的新扁平化维度中元素的实际排序而言,前 64 个索引与 [:,:,:,0]
相关] 第二个 64 [:,:,:,1]
和最后一个 64 [:,:,:,2]
?
是的,没错,数据的顺序保持不变(底层内存布局也是如此)。这是一个最小的例子:
>>> x = torch.rand(1, 1, 12, 3)
>>> x
tensor([[[[0.9942, 0.1458, 0.7069],
[0.2749, 0.9886, 0.3127],
[0.8236, 0.9903, 0.0779],
[0.0385, 0.2587, 0.7001],
[0.1113, 0.8011, 0.0176],
[0.1597, 0.1456, 0.4040],
[0.2168, 0.6588, 0.7472],
[0.1607, 0.6133, 0.0700],
[0.6749, 0.2120, 0.4192],
[0.8809, 0.8893, 0.4950],
[0.1695, 0.3175, 0.6653],
[0.7216, 0.5403, 0.8244]]]])
>>> x.flatten()
tensor([0.9942, 0.1458, 0.7069, 0.2749, 0.9886, 0.3127, 0.8236, 0.9903, 0.0779,
0.0385, 0.2587, 0.7001, 0.1113, 0.8011, 0.0176, 0.1597, 0.1456, 0.4040,
0.2168, 0.6588, 0.7472, 0.1607, 0.6133, 0.0700, 0.6749, 0.2120, 0.4192,
0.8809, 0.8893, 0.4950, 0.1695, 0.3175, 0.6653, 0.7216, 0.5403, 0.8244]))
为了便于理解,让我们首先以最简单的情况为例,其中我们有一个秩为 2 的张量,即正则矩阵。 PyTorch 在所谓的 row-major order 中执行展平,从“最内”轴遍历到“最外”轴。
取一个简单的 2 阶 3x3 数组,我们称它为 A[3, 3]
:
[[a, b, c],
[d, e, f],
[g, h, i]]
将其从最内轴到最外轴展平将得到 [a, b, c, d, e, f, g, h, i]
。让我们称这个扁平数组为 B[3]
.
A
(在索引[i, j]
处)和B
(在索引k
处)中对应元素之间的关系可以很容易地导出为:
k = A.size[1] * i + j
这是因为要到达 [i, j]
处的元素,我们首先向下移动 i
行,每行计算 A.size[1]
(即数组的宽度)个元素。一旦我们到达行 i
,我们需要到达列 j
,因此我们添加 j
以获得展平数组中的索引。
例如,元素 e
在 A
中的索引 [1, 1]
处。在 B
中,它将按预期占用索引 3 * 1 + 1 = 4
。
让我们将同样的想法扩展到阶数为 4 的张量,就像您的情况一样,我们只展平最后两个轴。
同样,采用形状为 (2, 2, 2, 2)
的简单 4 阶张量 A
,如下所示:
A =
[[[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]]],
[[[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16]]]]
让我们找到 A
和 torch.flatten(A, start_dim=2)
的索引之间的关系(我们称其为扁平化版本 B
)。
B =
[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]]]
元素 12 在 A
中的索引 [1, 1, 0, 0]
和 B
中的索引 [1, 1, 0]
处。请注意,轴 0 和 1 处的索引,即 [1, 1]
即使在部分展平后仍保持不变。这是因为这些轴没有展平,因此没有受到影响。
太棒了!因此,我们可以将 A
到 B
的转换表示为
B[i, j, _] = A[i, j, _, _]
我们的任务现在减少到找到 B
的最后一个轴和 A
的最后两个轴之间的关系。但是 A[i, j, _, _]
是一个 2x2 数组,我们已经推导出关系 k = A.size[1] * i + j
,
A.size[1]
现在将更改为 A.size[3]
,因为 3 现在是最后一个轴。但一般关系仍然存在。
填空得到A
和B
中对应元素的关系为:
B[i, j, k] = A[i, j, m, n]
其中 k = A.size[3] * m + n
.
我们可以验证这是正确的。元素 14 在 A
中的 [1, 1, 1, 0]
处。并移至 B
.
[1, 1, 2 * 1 + 0] = [1, 1, 2]
编辑: 添加示例
以@Molem7b5 的形状为 (1, 4, 4, 3)
的数组 A
为例,来自评论:
从 A
的内轴 (dim=3
) 迭代到外轴 (dim=2
) 得到 B
的连续元素。我的意思是:
// Using relation: A[:, :, i, j] == B[:, :, 3 * i + j]
// i = 0, all j
A[:, :, 0, 0] == B[:, :, 0]
A[:, :, 0, 1] == B[:, :, 1]
A[:, :, 0, 2] == B[:, :, 2]
// (Note the consecutive order in B.)
// i = 1, all j
A[:, :, 1, 0] == B[:, :, 3]
A[:, :, 1, 1] == B[:, :, 4]
// and so on until
A[:, :, 3, 2] == B[:, :, 11]
这应该可以让您更好地了解扁平化是如何发生的。如有疑问,请从关系中推断。