pytorch 广播是如何工作的?
How does pytorch broadcasting work?
torch.add(torch.ones(4,1), torch.randn(4))
生成一个张量,大小为:torch.Size([4,4])
。
有人可以提供这背后的逻辑吗?
PyTorch broadcasting
is based on numpy broadcasting semantics which can be understood by reading numpy broadcasting rules
or PyTorch broadcasting guide。用一个例子来阐述这个概念将是直观的,可以更好地理解它。所以,请看下面的例子:
In [27]: t_rand
Out[27]: tensor([ 0.23451, 0.34562, 0.45673])
In [28]: t_ones
Out[28]:
tensor([[ 1.],
[ 1.],
[ 1.],
[ 1.]])
现在 torch.add(t_rand, t_ones)
,想象一下:
# shape of (3,)
tensor([ 0.23451, 0.34562, 0.45673])
# (4, 1) | | | | | | | | | | | |
tensor([[ 1.],____+ | | | ____+ | | | ____+ | | |
[ 1.],______+ | | ______+ | | ______+ | |
[ 1.],________+ | ________+ | ________+ |
[ 1.]])_________+ __________+ __________+
应该给出形状为 (4,3)
的张量的输出:
# shape of (4,3)
In [33]: torch.add(t_rand, t_ones)
Out[33]:
tensor([[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673]])
此外,请注意,即使我们以与前一个参数相反的顺序传递参数,我们也会得到完全相同的结果:
# shape of (4, 3)
In [34]: torch.add(t_ones, t_rand)
Out[34]:
tensor([[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673]])
反正我更喜欢前一种理解方式,更直接直观。
为了图形理解,我挑选了更多示例,列举如下:
Example-1:
Example-2:
:
T
和 F
分别代表 True
和 False
,表示我们允许广播的维度(来源:Theano)。
Example-3:
这里有一些形状,其中数组 b
被适当地 广播 以 尝试 以匹配数组的形状 a
.
如上所示,广播的b
可能仍然不匹配a
的形状,因此只要最终广播的形状不匹配,操作a + b
就会失败。
a + b
的示例
设:
a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = ( 4, 1, 6, 7, 8)
第1步:b
会补上左边(只左边!)直到两者的轴数相同:
a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (1, 1, 4, 1, 6, 7, 8) <-- padded left with 1s
第2步:接下来,如果b
的轴长度为1
,将重复该轴直到其长度与相应的轴匹配a
:
a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (2, 3, 4, 5, 6, 7, 8) <-- changed 1s to match a
第3步:接下来,如果a
的轴长度为1
,将重复该轴,直到其长度与相应的轴匹配b
:
a.shape = (2, 3, 4, 5, 6, 7, 8) <-- changed 1s to match b
b.shape = (2, 3, 4, 5, 6, 7, 8)
这些形状匹配,所以a + b
将运行成功。 (如果它们不匹配,a + b
将失败。)
torch.add(torch.ones(4,1), torch.randn(4))
生成一个张量,大小为:torch.Size([4,4])
。
有人可以提供这背后的逻辑吗?
PyTorch broadcasting
is based on numpy broadcasting semantics which can be understood by reading numpy broadcasting rules
or PyTorch broadcasting guide。用一个例子来阐述这个概念将是直观的,可以更好地理解它。所以,请看下面的例子:
In [27]: t_rand
Out[27]: tensor([ 0.23451, 0.34562, 0.45673])
In [28]: t_ones
Out[28]:
tensor([[ 1.],
[ 1.],
[ 1.],
[ 1.]])
现在 torch.add(t_rand, t_ones)
,想象一下:
# shape of (3,)
tensor([ 0.23451, 0.34562, 0.45673])
# (4, 1) | | | | | | | | | | | |
tensor([[ 1.],____+ | | | ____+ | | | ____+ | | |
[ 1.],______+ | | ______+ | | ______+ | |
[ 1.],________+ | ________+ | ________+ |
[ 1.]])_________+ __________+ __________+
应该给出形状为 (4,3)
的张量的输出:
# shape of (4,3)
In [33]: torch.add(t_rand, t_ones)
Out[33]:
tensor([[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673]])
此外,请注意,即使我们以与前一个参数相反的顺序传递参数,我们也会得到完全相同的结果:
# shape of (4, 3)
In [34]: torch.add(t_ones, t_rand)
Out[34]:
tensor([[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673]])
反正我更喜欢前一种理解方式,更直接直观。
为了图形理解,我挑选了更多示例,列举如下:
Example-1:
Example-2:
:
T
和 F
分别代表 True
和 False
,表示我们允许广播的维度(来源:Theano)。
Example-3:
这里有一些形状,其中数组 b
被适当地 广播 以 尝试 以匹配数组的形状 a
.
如上所示,广播的b
可能仍然不匹配a
的形状,因此只要最终广播的形状不匹配,操作a + b
就会失败。
a + b
的示例
设:
a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = ( 4, 1, 6, 7, 8)
第1步:b
会补上左边(只左边!)直到两者的轴数相同:
a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (1, 1, 4, 1, 6, 7, 8) <-- padded left with 1s
第2步:接下来,如果b
的轴长度为1
,将重复该轴直到其长度与相应的轴匹配a
:
a.shape = (2, 3, 4, 5, 1, 1, 1)
b.shape = (2, 3, 4, 5, 6, 7, 8) <-- changed 1s to match a
第3步:接下来,如果a
的轴长度为1
,将重复该轴,直到其长度与相应的轴匹配b
:
a.shape = (2, 3, 4, 5, 6, 7, 8) <-- changed 1s to match b
b.shape = (2, 3, 4, 5, 6, 7, 8)
这些形状匹配,所以a + b
将运行成功。 (如果它们不匹配,a + b
将失败。)