Implementing momentum weight update for neural network

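I'm following mnielsen's online book. I'm trying to add the momentum weight update, as defined here, to his code here. The overall idea is that with a momentum update you don't change the weight vector directly by the negative gradient. Instead you keep a velocity parameter v, which you initialize to zero, and a hyperparameter mu, which is typically set to 0.9.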

# Momentum update
v = mu * v - learning_rate * dx # integrate velocity
x += v # integrate position

So in the following code snippet, my weights are w and the weight change is nabla_w:

def update_mini_batch(self, mini_batch, eta):
        """Update the network's weights and biases by applying
        gradient descent using backpropagation to a single mini batch.
        The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
        is the learning rate."""
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        for x, y in mini_batch:
            delta_nabla_b, delta_nabla_w = self.backprop(x, y)
            nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
            nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
        self.weights = [w-(eta/len(mini_batch))*nw
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch))*nb
                       for b, nb in zip(self.biases, nabla_b)]

So in the last lines of that method, self.weights is updated as

self.weights = [w-(eta/len(mini_batch))*nw
                for w, nw in zip(self.weights, nabla_w)]

For the momentum weight update, I'm doing the following:

self.momentum_v = [ (momentum_mu * self.momentum_v) - ( ( float(eta) / float(len(mini_batch)) )* nw) 
                   for nw in nebla_w ]
self.weights = [ w + v 
                for w, v in zip (self.weights, self.momentum_v)]

However, I get the following error:

 TypeError: can't multiply sequence by non-int of type 'float'

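for the momentum_v update. My eta hyperparameter is already a float, although I wrapped it in float() again anyway. I also wrapped len(mini_batch) in float(). I also tried nw.astype(float), but I still get the error. I'm not sure why. nabla_w is a numpy array of floats.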

As discussed in the comments, something here is not a numpy array. The error given above

TypeError: can't multiply sequence by non-int of type 'float'

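is the error Python raises for sequence types (lists, tuples, etc.). The message means that a sequence cannot be multiplied by a non-int. Sequences can be multiplied by an int, but that does not change the values; it just repeats the sequence, i.e.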

>>> [1, 0] * 3
[1, 0, 1, 0, 1, 0]

Of course, in this framework, multiplying by a float makes no sense:

>>> [1, 0] * 3.14
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can't multiply sequence by non-int of type 'float'

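That is the same error message you are seeing here. So one of the variables you are multiplying is in fact not a numpy array but one of the generic sequence types. A simple np.array() conversion around the offending variable will fix it, or of course you can change its definition to be an array.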
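As a minimal sketch of what a list-free multiplication could look like: judging from the snippet in the question, self.momentum_v is probably a Python list holding one array per layer (that is an assumption, since its initialization is not shown), which would make momentum_mu * self.momentum_v the offending expression. Because the layers generally have different shapes, updating each layer's velocity separately may be simpler than a single np.array() conversion. The function name momentum_step and the toy shapes below are hypothetical, chosen only for illustration.

```python
import numpy as np

def momentum_step(weights, velocities, nabla_w, eta, batch_size, mu=0.9):
    """Apply one momentum update and return (new_weights, new_velocities).

    weights, velocities and nabla_w are lists with one numpy array per layer.
    """
    # Pair each layer's velocity with its own gradient, so that mu only
    # ever multiplies a numpy array and never the enclosing Python list.
    new_velocities = [mu * v - (eta / batch_size) * nw
                      for v, nw in zip(velocities, nabla_w)]
    new_weights = [w + v for w, v in zip(weights, new_velocities)]
    return new_weights, new_velocities

# Toy two-layer network (shapes are made up for this example).
weights = [np.ones((3, 2)), np.ones((1, 3))]
velocities = [np.zeros(w.shape) for w in weights]   # velocity starts at zero
nabla_w = [np.full(w.shape, 2.0) for w in weights]  # stand-in summed gradients

weights, velocities = momentum_step(weights, velocities, nabla_w,
                                    eta=0.5, batch_size=10)
```

Inside update_mini_batch the same pattern would zip self.momentum_v with nabla_w, assuming self.momentum_v was initialized like nabla_w, as a list of np.zeros(w.shape) arrays.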