Pytorch:更新 numpy 数组而不更新相应的张量

Pytorch: Updating numpy array not updating the corresponding tensor

当我运行下面的代码时,

import numpy as np
a = np.ones(3)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

a和b都是2

然而,当我运行:

import numpy as np
a = np.ones(3)
b = torch.from_numpy(a)
a = a+1
print(a)
print(b)

b保持1s,a更新为2s

这是预期的行为吗?

是,如,操作

a = a + 1

创建原始数组的副本 a 并使用 broadcasting 加 1。在加法之后,由于我们将它分配给 a,所以 a 被更新为加法运算的结果。但是,b 仍然共享原始数组 a 的内存(即更新前创建的数组 a。)

所以,我们看到这样的结果:

In [75]: a = np.ones(3)
    ...: b = torch.from_numpy(a)
    ...: a = a+1     # <========= creates copy of `a` and modifies it
    ...: print(a)
    ...: print(b)
    ...: 
[ 2.  2.  2.]

 1
 1
 1
[torch.DoubleTensor of size 3]

但是,看看当你愿意这样做时会发生什么:

In [72]: a = np.ones(3)
    ...: b = torch.from_numpy(a)
    ...: a += 1      # <========== in-place modification of `a`
    ...: print(a)
    ...: print(b)
    ...:

[ 2.  2.  2.]

 2
 2
 2
[torch.DoubleTensor of size 3]

观察 += 操作如何对原始数组进行修改 就地 somearr = somearr + 1 创建数组的 副本 somearray 然后对其进行修改