pytorch 广播是如何工作的?

How does pytorch broadcasting work?

torch.add(torch.ones(4,1), torch.randn(4))

生成一个张量,大小为:torch.Size([4,4])

有人可以提供这背后的逻辑吗?

PyTorch broadcasting is based on numpy broadcasting semantics which can be understood by reading numpy broadcasting rules or PyTorch broadcasting guide。用一个例子来阐述这个概念将是直观的,可以更好地理解它。所以,请看下面的例子:

In [27]: t_rand
Out[27]: tensor([ 0.23451,  0.34562,  0.45673])

In [28]: t_ones
Out[28]: 
tensor([[ 1.],
        [ 1.],
        [ 1.],
        [ 1.]])

现在 torch.add(t_rand, t_ones),想象一下:

               # shape of (3,)
               tensor([ 0.23451,      0.34562,       0.45673])
      # (4, 1)          | | | |       | | | |        | | | |
      tensor([[ 1.],____+ | | |   ____+ | | |    ____+ | | |
              [ 1.],______+ | |   ______+ | |    ______+ | |
              [ 1.],________+ |   ________+ |    ________+ |
              [ 1.]])_________+   __________+    __________+

应该给出形状为 (4,3) 的张量的输出:

# shape of (4,3)
In [33]: torch.add(t_rand, t_ones)
Out[33]: 
tensor([[ 1.23451,  1.34562,  1.45673],
        [ 1.23451,  1.34562,  1.45673],
        [ 1.23451,  1.34562,  1.45673],
        [ 1.23451,  1.34562,  1.45673]])

此外,请注意,即使我们以与前一个参数相反的顺序传递参数,我们也会得到完全相同的结果:

# shape of (4, 3)
In [34]: torch.add(t_ones, t_rand)
Out[34]: 
tensor([[ 1.23451,  1.34562,  1.45673],
        [ 1.23451,  1.34562,  1.45673],
        [ 1.23451,  1.34562,  1.45673],
        [ 1.23451,  1.34562,  1.45673]])

反正我更喜欢前一种理解方式,更直接直观。


为了图形理解,我挑选了更多示例,列举如下:

Example-1:


Example-2::

TF 分别代表 TrueFalse,表示我们允许广播的维度(来源:Theano)。


Example-3:

这里有一些形状,其中数组 b 被适当地 广播 尝试 以匹配数组的形状 a.

如上所示,广播的b可能仍然不匹配a的形状,因此只要最终广播的形状不匹配,操作a + b就会失败。

a + b

的示例

设:

a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (      4, 1, 6, 7, 8)

第1步:b会补上左边左边!)直到两者的轴数相同:

a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (1, 1, 4, 1, 6, 7, 8)    <-- padded left with 1s

第2步:接下来,如果b的轴长度为1,将重复该轴直到其长度与相应的轴匹配a:

a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (2, 3, 4, 5, 6, 7, 8)    <-- changed 1s to match a

第3步:接下来,如果a的轴长度为1,将重复该轴,直到其长度与相应的轴匹配b:

a.shape = (2, 3, 4, 5, 6, 7, 8)    <-- changed 1s to match b
b.shape = (2, 3, 4, 5, 6, 7, 8)

这些形状匹配,所以a + b将运行成功。 (如果它们不匹配,a + b 将失败。)