tf.keras.layers.Dense - 参数数量?

tf.keras.layers.Dense - number of parameters?

我一直在使用 keras functional API 来构建一个不错的网络。但是,我不明白 tf.keras.layers.Dense 中的空间连通性是如何工作的。

如果我展平一个 7x7x1024 的体积,我会得到 50,176 个参数。我希望两层之间的参数总数为

50,176 * 4096 + 4096 = 205,524,992

是的。

令人惊讶的是,当我删除层 Flatten() 时,我没有收到任何维度不兼容错误。输出shape为7x7x4096,参数个数为:

1024*4096 + 4096 = 4,198,400

如果这是正确的,为什么 tf.keras.layers.Dense 层的最后维度之间只有密集连接,为什么输出是 7x7x4096 体积?

(last layer is 7 x 7 x 1024 volume) 
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(4096)(x)

当您传递 dim>2 的张量时,Dense 会创建与最后一个维度的连接作为默认行为 [1](第 889 行,input_dim = input_shape[-1]),即为什么你没有得到任何错误。结果你也得到了你已经计算的参数数量。 因此,如果您使用的是 3D 输入,则需要在将其传递到 Dense 层之前将其展平。

[1] https://github.com/keras-team/keras/blob/master/keras/layers/core.py#L796