pytorch 中的张量幂和乘法

Tensor power and multiplication in pytorch

我有一个矩阵 A 和一个大小为 (1,3) 的张量 b - 所以一个大小为 3 的向量。

我要计算

C = b1 * A + b2 * A^2 + b3 * A^3 其中 ^n 是 A 的 n 次方。

最后,C 的形状应该与 A 相同。我怎样才能高效地做到这一点?

让我们试试:

A = torch.ones(1,2,3)
b_vals = torch.tensor([2,3,4])
powers = torch.tensor([1,2,3])

C = (A[...,None]**powers + b_vals).sum(-1)

输出:

tensor([[[12., 12., 12.],
         [12., 12., 12.]]])