TensorLy 中的克罗内克乘积源代码
Kronecker product source code in TensorLy
我正在尝试理解在 TensorLy 中实现的张量 Kronecker 乘积的代码。下面是代码:
def kron(self, a, b):
"""Kronecker product of two tensors.
Parameters
----------
a, b : tensor
The tensors to compute the kronecker product of.
Returns
-------
tensor
"""
s1, s2 = self.shape(a)
s3, s4 = self.shape(b)
a = self.reshape(a, (s1, 1, s2, 1))
b = self.reshape(b, (1, s3, 1, s4))
return self.reshape(a * b, (s1 * s3, s2 * s4))
我明白 self.shape(a)
会给出张量的形状 a
(行、列、切片)。所以我们在 s1
和 s2
中采用 a
的形状,在 s3
和 s4
中采用 b
的形状。
a = self.reshape(a, (s1, 1, s2, 1))
重塑张量 'a',但我发现很难理解什么是 (s1, 1, s2, 1)
以及我们为什么要这样做? (1, s3, 1, s4)
也是如此。另外,我们为什么要这样做 self.reshape(a * b, (s1 * s3, s2 * s4))
?。
这似乎是一个非常开放的问题,但我才刚刚开始,希望得到帮助!
这是一个相当常见的使用广播的技巧。在该对齐方式中将单位维度插入 a
和 b
会发生以下情况:
- 在第一个轴中,
b
被复制 s1
次,以匹配 a
的每一行。
- 在第二个轴中,
a
被复制 s3
次以匹配 b
的每一行。
- 在第三个轴中,
b
被复制 s2
次,以匹配 a
的每一列。
- 在第四个轴中,
a
被复制 s4
次,以匹配 b
的每一列。
当您进行乘法运算时,您最终会得到每个元素组合的 4D 乘积。元素 result[i, j, m, n]
来自 a[i, m] * b[j, n]
最终的 reshape 在内存中获取相同的数据,并在不重新排列数据的情况下组合前两个和最后两个轴。
让我们来看一个简单的例子:
a = [[1, 2, 3],
[2, 3, 4],
[3, 4, 5]]
b = [[6, 7]]
形状由(3, 3)
和(1, 2)
更改为(3, 1, 3, 1)
和(1, 1, 1, 2)
。这不会改变内存中的布局,因此 a
变为
[[[[1], [2], [3]]],
[[[2], [3], [4]]],
[[[3], [4], [5]]]]
b
变为
[[[[6, 7]]]]
结果的形状为 (3, 1, 3, 2)
,看起来像这样:
[[[[1*6, 1*7], [2*6, 2*7], [3*6, 3*7]]],
[[[2*6, 2*7], [3*6, 3*7], [4*6, 4*7]]],
[[[3*6, 3*7], [4*6, 4*7], [5*6, 5*7]]]]
将其重塑为最终结果时,内存布局保持不变,但形状变为 (3*1, 3*2)
:
[[1*6, 1*7, 2*6, 2*7, 3*6, 3*7],
[2*6, 2*7, 3*6, 3*7, 4*6, 4*7],
[3*6, 3*7, 4*6, 4*7, 5*6, 5*7]]
瞧,看看 a
和 b
的克罗内克积。
我正在尝试理解在 TensorLy 中实现的张量 Kronecker 乘积的代码。下面是代码:
def kron(self, a, b):
"""Kronecker product of two tensors.
Parameters
----------
a, b : tensor
The tensors to compute the kronecker product of.
Returns
-------
tensor
"""
s1, s2 = self.shape(a)
s3, s4 = self.shape(b)
a = self.reshape(a, (s1, 1, s2, 1))
b = self.reshape(b, (1, s3, 1, s4))
return self.reshape(a * b, (s1 * s3, s2 * s4))
我明白 self.shape(a)
会给出张量的形状 a
(行、列、切片)。所以我们在 s1
和 s2
中采用 a
的形状,在 s3
和 s4
中采用 b
的形状。
a = self.reshape(a, (s1, 1, s2, 1))
重塑张量 'a',但我发现很难理解什么是 (s1, 1, s2, 1)
以及我们为什么要这样做? (1, s3, 1, s4)
也是如此。另外,我们为什么要这样做 self.reshape(a * b, (s1 * s3, s2 * s4))
?。
这似乎是一个非常开放的问题,但我才刚刚开始,希望得到帮助!
这是一个相当常见的使用广播的技巧。在该对齐方式中将单位维度插入 a
和 b
会发生以下情况:
- 在第一个轴中,
b
被复制s1
次,以匹配a
的每一行。 - 在第二个轴中,
a
被复制s3
次以匹配b
的每一行。 - 在第三个轴中,
b
被复制s2
次,以匹配a
的每一列。 - 在第四个轴中,
a
被复制s4
次,以匹配b
的每一列。
当您进行乘法运算时,您最终会得到每个元素组合的 4D 乘积。元素 result[i, j, m, n]
来自 a[i, m] * b[j, n]
最终的 reshape 在内存中获取相同的数据,并在不重新排列数据的情况下组合前两个和最后两个轴。
让我们来看一个简单的例子:
a = [[1, 2, 3],
[2, 3, 4],
[3, 4, 5]]
b = [[6, 7]]
形状由(3, 3)
和(1, 2)
更改为(3, 1, 3, 1)
和(1, 1, 1, 2)
。这不会改变内存中的布局,因此 a
变为
[[[[1], [2], [3]]],
[[[2], [3], [4]]],
[[[3], [4], [5]]]]
b
变为
[[[[6, 7]]]]
结果的形状为 (3, 1, 3, 2)
,看起来像这样:
[[[[1*6, 1*7], [2*6, 2*7], [3*6, 3*7]]],
[[[2*6, 2*7], [3*6, 3*7], [4*6, 4*7]]],
[[[3*6, 3*7], [4*6, 4*7], [5*6, 5*7]]]]
将其重塑为最终结果时,内存布局保持不变,但形状变为 (3*1, 3*2)
:
[[1*6, 1*7, 2*6, 2*7, 3*6, 3*7],
[2*6, 2*7, 3*6, 3*7, 4*6, 4*7],
[3*6, 3*7, 4*6, 4*7, 5*6, 5*7]]
瞧,看看 a
和 b
的克罗内克积。