How to apply a different dense layer to each row of a matrix in Keras
The output of my previous layer has shape (None, 30, 600). I want to multiply each row of that matrix by a different (600, 600) matrix, or equivalently, multiply the whole matrix by a 3D weight tensor. This could be done by applying a different dense layer to each row. I tried the TimeDistributed wrapper, but it applies the same dense layer to every row. I also tried a Lambda layer like this:
Lambda(lambda x: tf.stack(x, axis=1))(
    Lambda(lambda x: [Dense(600)(each) for each in tf.unstack(x, axis=1)])(prev_layer_output)
)
This seemed to solve the problem, and I was able to train the model. But I noticed that model.summary() doesn't recognize those dense layers, and they aren't reflected in the total number of trainable parameters. Moreover, I couldn't restore their weights when loading the model, so all that training was wasted. How can I fix this? How do I apply a different dense layer to each row of a matrix?
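(For reference: TimeDistributed wraps a single layer instance and reuses its weights on every slice of axis 1, which is why that attempt shares one kernel across all 30 rows. A quick check, assuming standard Keras imports:)
from keras.layers import Input, Dense, TimeDistributed

inp = Input((30, 600))
td = TimeDistributed(Dense(600))
out = td(inp)
print(len(td.trainable_weights))  # => 2: one shared kernel and one shared bias, not 30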
You could element-wise multiply the (30, 600) matrix with a (600, 30, 600) tensor; that gives you a (600, 30, 600) result, and if you then sum over the last dimension you get the transpose of what you want. I tested this in numpy rather than tensorflow, but it should work the same.
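A quick numpy sketch of that comment (the array names x and w are illustrative, not from the post):
import numpy as np

x = np.random.rand(30, 600)        # one (30, 600) slice of the (None, 30, 600) output
w = np.random.rand(600, 30, 600)   # w[o, r, :] holds row r's weights for output unit o

# Broadcast multiply (30, 600) against (600, 30, 600), then sum the last axis:
y_t = (x * w).sum(axis=-1)         # shape (600, 30) -- the transpose of the target

# The equivalent per-row matmul, row r times its own (600, 600) matrix:
y = np.einsum('re,ore->ro', x, w)  # shape (30, 600)
assert np.allclose(y, y_t.T)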
Instead of wrapping everything into a single Lambda layer, you can use multiple layers.
from keras.layers import Input, Dense, Lambda
from keras.models import Model
import keras.backend as K

x = Input((30, 600))
unstacked = Lambda(lambda t: K.tf.unstack(t, axis=1))(x)      # 30 tensors of shape (None, 600)
dense_outputs = [Dense(600)(t) for t in unstacked]            # a separate Dense(600) per row
merged = Lambda(lambda t: K.stack(t, axis=1))(dense_outputs)  # back to (None, 30, 600)
model = Model(x, merged)
Now you can see the 30 Dense(600) layers in model.summary():
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 30, 600) 0
__________________________________________________________________________________________________
lambda_1 (Lambda) [(None, 600), (None, 0 input_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 600) 360600 lambda_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 600) 360600 lambda_1[0][1]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 600) 360600 lambda_1[0][2]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 600) 360600 lambda_1[0][3]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 600) 360600 lambda_1[0][4]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 600) 360600 lambda_1[0][5]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 600) 360600 lambda_1[0][6]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 600) 360600 lambda_1[0][7]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 600) 360600 lambda_1[0][8]
__________________________________________________________________________________________________
dense_10 (Dense) (None, 600) 360600 lambda_1[0][9]
__________________________________________________________________________________________________
dense_11 (Dense) (None, 600) 360600 lambda_1[0][10]
__________________________________________________________________________________________________
dense_12 (Dense) (None, 600) 360600 lambda_1[0][11]
__________________________________________________________________________________________________
dense_13 (Dense) (None, 600) 360600 lambda_1[0][12]
__________________________________________________________________________________________________
dense_14 (Dense) (None, 600) 360600 lambda_1[0][13]
__________________________________________________________________________________________________
dense_15 (Dense) (None, 600) 360600 lambda_1[0][14]
__________________________________________________________________________________________________
dense_16 (Dense) (None, 600) 360600 lambda_1[0][15]
__________________________________________________________________________________________________
dense_17 (Dense) (None, 600) 360600 lambda_1[0][16]
__________________________________________________________________________________________________
dense_18 (Dense) (None, 600) 360600 lambda_1[0][17]
__________________________________________________________________________________________________
dense_19 (Dense) (None, 600) 360600 lambda_1[0][18]
__________________________________________________________________________________________________
dense_20 (Dense) (None, 600) 360600 lambda_1[0][19]
__________________________________________________________________________________________________
dense_21 (Dense) (None, 600) 360600 lambda_1[0][20]
__________________________________________________________________________________________________
dense_22 (Dense) (None, 600) 360600 lambda_1[0][21]
__________________________________________________________________________________________________
dense_23 (Dense) (None, 600) 360600 lambda_1[0][22]
__________________________________________________________________________________________________
dense_24 (Dense) (None, 600) 360600 lambda_1[0][23]
__________________________________________________________________________________________________
dense_25 (Dense) (None, 600) 360600 lambda_1[0][24]
__________________________________________________________________________________________________
dense_26 (Dense) (None, 600) 360600 lambda_1[0][25]
__________________________________________________________________________________________________
dense_27 (Dense) (None, 600) 360600 lambda_1[0][26]
__________________________________________________________________________________________________
dense_28 (Dense) (None, 600) 360600 lambda_1[0][27]
__________________________________________________________________________________________________
dense_29 (Dense) (None, 600) 360600 lambda_1[0][28]
__________________________________________________________________________________________________
dense_30 (Dense) (None, 600) 360600 lambda_1[0][29]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, 30, 600) 0 dense_1[0][0]
dense_2[0][0]
dense_3[0][0]
dense_4[0][0]
dense_5[0][0]
dense_6[0][0]
dense_7[0][0]
dense_8[0][0]
dense_9[0][0]
dense_10[0][0]
dense_11[0][0]
dense_12[0][0]
dense_13[0][0]
dense_14[0][0]
dense_15[0][0]
dense_16[0][0]
dense_17[0][0]
dense_18[0][0]
dense_19[0][0]
dense_20[0][0]
dense_21[0][0]
dense_22[0][0]
dense_23[0][0]
dense_24[0][0]
dense_25[0][0]
dense_26[0][0]
dense_27[0][0]
dense_28[0][0]
dense_29[0][0]
dense_30[0][0]
==================================================================================================
Total params: 10,818,000
Trainable params: 10,818,000
Non-trainable params: 0
__________________________________________________________________________________________________
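As an aside (not part of the original answer): the same per-row computation can also be packed into a single custom layer that owns one 3D weight tensor, which keeps the summary to one line. A minimal sketch, assuming a Keras 2 / TensorFlow 1.x backend where tf.einsum supports this pattern; RowWiseDense is a made-up name:
from keras.layers import Layer
import keras.backend as K

class RowWiseDense(Layer):
    # Applies a different square kernel to each row; bias omitted for brevity.
    def build(self, input_shape):
        rows, dim = input_shape[1], input_shape[2]
        self.kernel = self.add_weight(name='kernel',
                                      shape=(rows, dim, dim),
                                      initializer='glorot_uniform',
                                      trainable=True)
        super(RowWiseDense, self).build(input_shape)

    def call(self, inputs):
        # out[b, r, o] = sum_e inputs[b, r, e] * kernel[r, e, o]
        return K.tf.einsum('bre,reo->bro', inputs, self.kernel)
Used as y = RowWiseDense()(x) on a (None, 30, 600) tensor, it returns (None, 30, 600) again; the base layer's default compute_output_shape (identity) is correct here because the kernels are square.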
Edit: to verify that the model is learning:
import numpy as np

model.compile(loss='mse', optimizer='adam')
model.fit(np.random.rand(100, 30, 600), np.random.rand(100, 30, 600), epochs=10)
You should see the loss decreasing:
Epoch 1/10
100/100 [==============================] - 1s 15ms/step - loss: 0.4725
Epoch 2/10
100/100 [==============================] - 0s 1ms/step - loss: 0.2211
Epoch 3/10
100/100 [==============================] - 0s 1ms/step - loss: 0.2405
Epoch 4/10
100/100 [==============================] - 0s 1ms/step - loss: 0.2013
Epoch 5/10
100/100 [==============================] - 0s 1ms/step - loss: 0.1771
Epoch 6/10
100/100 [==============================] - 0s 1ms/step - loss: 0.1676
Epoch 7/10
100/100 [==============================] - 0s 1ms/step - loss: 0.1568
Epoch 8/10
100/100 [==============================] - 0s 1ms/step - loss: 0.1473
Epoch 9/10
100/100 [==============================] - 0s 1ms/step - loss: 0.1400
Epoch 10/10
100/100 [==============================] - 0s 1ms/step - loss: 0.1343
You can also verify that the weights are actually being updated by comparing their values before and after fitting the model:
w0 = model.get_weights()
model.fit(np.random.rand(100,30,600), np.random.rand(100,30,600), epochs=10)
w1 = model.get_weights()
print(not any(np.allclose(x0, x1) for x0, x1 in zip(w0, w1)))
# => True
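Since the 30 Dense layers are now real layers in the graph, their weights are included in checkpoints, which addresses the original complaint about losing them on reload. A minimal check (the filename is illustrative):
model.save_weights('row_dense.h5')   # persists all 30 kernels and biases
model.load_weights('row_dense.h5')   # restores them into the same architecture
Note that saving the full model with model.save would additionally require the Lambda functions to be deserializable at load time; saving and loading only the weights sidesteps that.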