np.add.at 用数组索引

np.add.at indexing with array

我正在研究 cs231n,我很难理解这个索引是如何工作的。鉴于

x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[  1.19034710e-01  -4.65005990e-01   8.93743168e-01  -9.78047129e-01
            -8.88672957e-01  -4.66605091e-01]
         [ -1.38617461e-03  -2.64569728e-01  -3.83712733e-01  -2.61360826e-01
            8.07072009e-01  -5.47607277e-01]
         [ -3.97087458e-01  -4.25187949e-02   2.57931759e-01   7.49565950e-01
           1.37707667e+00   1.77392240e+00]]

       [[ -1.20692745e+00  -8.28111550e-01   6.53041092e-01  -2.31247762e+00
         -1.72370321e+00   2.44308033e+00]
        [ -1.45191870e+00  -3.49328154e-01   6.15445782e-01  -2.84190582e-01
           4.85997687e-02   4.81590106e-01]
        [ -1.14828583e+00  -9.69055406e-01  -1.00773809e+00   3.63553835e-01
          -1.28078363e+00  -2.54448436e+00]]]

他们做的操作是

np.add.at(dW, x, dout)

x 是一个二维数组。索引在这里如何工作?我浏览了 np.ufunc.at 文档,但他们有带有 1d 数组和常量的简单示例:

np.add.at(a, [0, 1, 2, 2], 1)
In [226]: x = [[0,4,1], [3,2,4]]
     ...: dW = np.zeros((5,6),int)

In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]: 
array([[0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0]])

使用此 x 没有任何重复条目,因此 add.at 与使用 += 索引相同。等效地,我们可以读取更改后的值:

In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])

索引以两种方式工作,包括广播:

In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]: 
array([[0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]])

可能的值

值必须是 broadcastable,关于索引:

In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])

In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])

In [118]: dW
Out[118]: 
array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  3,  0,  9,  0],
       [ 0,  0,  4,  0, 11,  0],
       [ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0]])

在这种情况下,索引定义了 (2,3) 形状,因此 (2,3)、(3,)、(2,1) 和标量值有效。 (6,) 没有。

在这种情况下,add.at 将 (2,3) 数组映射到 dW 的 (2,2) 子数组。

最近也很难理解这行代码。希望我得到的可以帮助到你,如果我说错了请指正。

这行代码中的三个数组如下:

x , whose shape is (N,T)
dW,  ---(V,D)
dout ---(N,T,D)

然后我们来到我们要弄清楚发生了什么的代码行

np.add.at(dW, x, dout)

如果你不想知道思维过程。上面的代码相当于:

for row in range(N):
   for col in range(T):
      dW[ x[row,col]  , :] += dout[row,col, :]

这是思考过程:

引用此文档

https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ufunc.at.html

我们知道x是索引数组。所以关键是理解dW[x]。 这是使用另一个数组 (x) 索引数组 (dW) 的概念。如果你不熟悉这个概念,可以看看这个link

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html

Generally speaking, what is returned when index arrays are used is an array with the same shape as the index array, but with the type and values of the array being indexed.

dW[x] 将给我们一个数组,其形状为 (N,T,D),(N,T) 部分来自 x,而 (D) 来自 dW (V,D)。这里注意,x的每个元素都在[0, v)的范围内。

让我们以一些数字为例

x:    np.array([[0,0],[0,0]]) ---- (2,2) N=2, T=2
dW:   np.array([[0,0],[2,2]]) ---- (2,2) V=2, D=2
dout: np.arange(1,9).reshape(2,2,2)  ----(2,2,2) N=2, T=2, D=2

dW[x] should be [ [[0 0] #this comes from the dW's firt row
                  [0 0]]

                  [[0 0]
                   [0 0]] ]

dW[x] add dout 表示添加elemnet item(这里是小技巧,后面会解释)

np.add.at(dW, x, dout) gives 
 [ [16 20]
   [ 2  2] ]

为什么?程序是:

它把[1,2]加到dW的第一行,即[0,0]。

为什么是第一行?因为x[0,0] = 0,表示第一行是dW,dW[0] = dW[0,:] = 第一行。

然后将[3,4]加到dW[0,0]的第一行。 [3,4]=dout[0,1,:]。 又是[0,0],来自dW,x[0,1] = 0,仍然是dW[0]的第一行。

然后将[5,6]添加到dW的第一行。

然后将[7,8]添加到dW的第一行。

所以结果是[1+3+5+7, 2+4+6+8] = [16,20]。因为我们没有触及 dW 的第二行。 dW的第二行保持不变。

诀窍是我们只计算原行一次,可以认为没有缓冲,每一步都在原来的地方播放。

让我们考虑一个基于 cs231n 作业的示例。如果我们谈论的是多个方向,那么使用具体设置会容易得多。

np.random.seed(1)
N, T, V, D = 2, 3, 7, 6
x = np.random.randint(V, size=(N, T))
dW_man = np.zeros((V, D))

dW_man[x].shape, x.shape
((2, 3, 6), (2, 3))

x
array([[5, 3, 4],
   [0, 1, 3]])

dout = np.arange(2*3*6).reshape(dW_man[x].shape)
dout
array([[[ 0,  1,  2,  3,  4,  5],
    [ 6,  7,  8,  9, 10, 11],
    [12, 13, 14, 15, 16, 17]],

   [[18, 19, 20, 21, 22, 23],
    [24, 25, 26, 27, 28, 29],
    [30, 31, 32, 33, 34, 35]]])

dW_man[x] 的行应该是什么?那么 [0, 1, ...] 应该添加到第 5 行,[ 6, 7, ..] - 添加到第 3 行。并且 [30, 31, ...] 应该添加到第 3 行。所以让我们手动计算它。在此 GitHub 要点中查看更多示例和解释:link.

dW_man[5] = dout[0, 0]
dW_man[3] = dout[0, 1]
dW_man[4] = dout[0, 2]

dW_man[0] = dout[1, 0]
dW_man[1] = dout[1, 1]
dW_man[3] = dout[1, 2]

dW_man
array([[18., 19., 20., 21., 22., 23.],
   [24., 25., 26., 27., 28., 29.],
   [ 0.,  0.,  0.,  0.,  0.,  0.],
   [30., 31., 32., 33., 34., 35.],
   [12., 13., 14., 15., 16., 17.],
   [ 0.,  1.,  2.,  3.,  4.,  5.],
   [ 0.,  0.,  0.,  0.,  0.,  0.]])

现在让我们使用 np.add.at

np.random.seed(1)
N, T, V, D = 2, 3, 7, 6
x = np.random.randint(V, size=(N, T))
dW = np.zeros((V, D))
dout = np.arange(2*3*6).reshape(dW[x].shape)
np.add.at(dW, x, dout)

dW
array([[18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29.],
       [ 0.,  0.,  0.,  0.,  0.,  0.],
       [36., 38., 40., 42., 44., 46.],
       [12., 13., 14., 15., 16., 17.],
       [ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 0.,  0.,  0.,  0.,  0.,  0.]])