How to define plenty of diagonal matrices?
Let's consider the tensor:
scale = torch.tensor([[1.0824, 1.0296, 1.0065, 0.9395, 0.9424, 1.0260, 0.9805, 1.0509],
                      [1.1002, 1.0358, 1.0112, 0.9466, 0.9454, 0.9942, 0.9891, 1.0485],
                      [1.1060, 1.0157, 1.0216, 0.9544, 0.9378, 1.0160, 0.9671, 1.0240]])
Its shape is:
scale.shape
torch.Size([3, 8])
I want a tensor of shape [3, 8, 8] containing three diagonal matrices built from the values of scale. In other words, the diagonal of the first matrix uses only scale[0], the second uses scale[1], and the last uses scale[2].
We can do it the naive way:
import torch
temp = torch.tensor([])
for i in range(0, 3):
    temp = torch.cat([temp, torch.diag(scale[i])])
temp = temp.view(3, 8, 8)
temp
But I wonder whether there is another, more efficient way to do this.
I think you are looking for diag_embed:
temp = torch.diag_embed(scale)
For example:
scale = torch.arange(24).view(3,8)
torch.diag_embed(scale)
tensor([[[ 0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  1,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  2,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  3,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  4,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  5,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  6,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  7]],

        [[ 8,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  9,  0,  0,  0,  0,  0,  0],
         [ 0,  0, 10,  0,  0,  0,  0,  0],
         [ 0,  0,  0, 11,  0,  0,  0,  0],
         [ 0,  0,  0,  0, 12,  0,  0,  0],
         [ 0,  0,  0,  0,  0, 13,  0,  0],
         [ 0,  0,  0,  0,  0,  0, 14,  0],
         [ 0,  0,  0,  0,  0,  0,  0, 15]],

        [[16,  0,  0,  0,  0,  0,  0,  0],
         [ 0, 17,  0,  0,  0,  0,  0,  0],
         [ 0,  0, 18,  0,  0,  0,  0,  0],
         [ 0,  0,  0, 19,  0,  0,  0,  0],
         [ 0,  0,  0,  0, 20,  0,  0,  0],
         [ 0,  0,  0,  0,  0, 21,  0,  0],
         [ 0,  0,  0,  0,  0,  0, 22,  0],
         [ 0,  0,  0,  0,  0,  0,  0, 23]]])
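Note that diag_embed is not limited to a single batch dimension: it writes the input values onto the diagonals of the last two dimensions of the output, so any leading batch shape is preserved. A minimal sketch (the shapes here are just for illustration):
import torch
# two leading batch dimensions: shape [2, 3, 8]
batched_scale = torch.randn(2, 3, 8)
# diag_embed appends one dimension and fills the diagonals of the
# last two dims, giving shape [2, 3, 8, 8]
batched_diag = torch.diag_embed(batched_scale)
print(batched_diag.shape)  # torch.Size([2, 3, 8, 8])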
If you insist on building the matrices in a loop, you can at least use a list comprehension with torch.stack instead of repeated torch.cat calls:
temp = torch.stack([torch.diag(s_) for s_ in scale])
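As a quick sanity check (a minimal sketch using a random scale of the same shape as in the question), all three approaches produce the same result; diag_embed just does it in a single vectorized call:
import torch
scale = torch.randn(3, 8)

# original loop with torch.cat
temp_loop = torch.tensor([])
for i in range(0, 3):
    temp_loop = torch.cat([temp_loop, torch.diag(scale[i])])
temp_loop = temp_loop.view(3, 8, 8)

# list comprehension with torch.stack
temp_stack = torch.stack([torch.diag(s_) for s_ in scale])

# vectorized diag_embed
temp_embed = torch.diag_embed(scale)

print(torch.equal(temp_loop, temp_embed))   # True
print(torch.equal(temp_stack, temp_embed))  # True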