Keras 的 tensorflow-backend 的复杂矩阵乘法
Complex matrix multiplication with tensorflow-backend of Keras
设矩阵F1
的形状为(a * h * w * m)
,矩阵F2
的形状为(a * h * w * n)
,矩阵G
的形状为(a * m * n)
.
我想实现以下公式,使用 Keras 的 tensorflow 后端从 F1
和 F2
的因子计算 G
的每个因子。但是我对各种后端函数感到困惑,尤其是 K.dot()
和 K.batch_dot()
.
$$ G_{k, i, j} = \sum^h_{s=1} \sum^w_{t=1} \dfrac{F^1_{k, s, t, i} * F^2_{k, s, t, j}}{h * w} $$ 即:
(复制$$内的上述等式粘贴到this site得到的图片)
有什么方法可以实现上面的公式吗?提前谢谢你。
使用 Tensorflow tf.einsum()
(您可以将其包裹在 Keras 的 Lambda
层中):
import tensorflow as tf
import numpy as np
a, h, w, m, n = 1, 2, 3, 4, 5
F1 = tf.random_uniform(shape=(a, h, w, m))
F2 = tf.random_uniform(shape=(a, h, w, n))
G = tf.einsum('ahwm,ahwn->amn', F1, F2) / (h * w)
with tf.Session() as sess:
f1, f2, g = sess.run([F1, F2, G])
# Manually computing G to check our operation, reproducing naively your equation:
g_check = np.zeros(shape=(a, m, n))
for k in range(a):
for i in range(m):
for j in range(n):
for s in range(h):
for t in range(w):
g_check[k, i, j] += f1[k,s,t,i] * f2[k,s,t,j] / (h * w)
# Checking for equality:
print(np.allclose(g, g_check))
# > True
设矩阵F1
的形状为(a * h * w * m)
,矩阵F2
的形状为(a * h * w * n)
,矩阵G
的形状为(a * m * n)
.
我想实现以下公式,使用 Keras 的 tensorflow 后端从 F1
和 F2
的因子计算 G
的每个因子。但是我对各种后端函数感到困惑,尤其是 K.dot()
和 K.batch_dot()
.
$$ G_{k, i, j} = \sum^h_{s=1} \sum^w_{t=1} \dfrac{F^1_{k, s, t, i} * F^2_{k, s, t, j}}{h * w} $$ 即:
(复制$$内的上述等式粘贴到this site得到的图片)
有什么方法可以实现上面的公式吗?提前谢谢你。
使用 Tensorflow tf.einsum()
(您可以将其包裹在 Keras 的 Lambda
层中):
import tensorflow as tf
import numpy as np
a, h, w, m, n = 1, 2, 3, 4, 5
F1 = tf.random_uniform(shape=(a, h, w, m))
F2 = tf.random_uniform(shape=(a, h, w, n))
G = tf.einsum('ahwm,ahwn->amn', F1, F2) / (h * w)
with tf.Session() as sess:
f1, f2, g = sess.run([F1, F2, G])
# Manually computing G to check our operation, reproducing naively your equation:
g_check = np.zeros(shape=(a, m, n))
for k in range(a):
for i in range(m):
for j in range(n):
for s in range(h):
for t in range(w):
g_check[k, i, j] += f1[k,s,t,i] * f2[k,s,t,j] / (h * w)
# Checking for equality:
print(np.allclose(g, g_check))
# > True