如何在 Keras 中按列拆分张量以实现 STFCN
How to split a tensor column-wise in Keras to implement STFCN
我想在 Keras 中实现时空全卷积网络 (STFCN)。我需要提供 3D 卷积输出的每个深度列,例如形状为 (64, 16, 16)
的张量,作为单独 LSTM 的输入。
为了说明这一点,我有一个 (64 x 16 x 16)
张量,尺寸为 (channels, height, width)
。我需要将张量(显式或隐式)拆分为 16 * 16 = 256 个形状为 (64 x 1 x 1)
.
的张量
这是来自 STFCN 论文的图表,用于说明时空模块。我上面描述的是'Spatial Features'和'Spatio-Temporal Module'之间的箭头。
如何在 Keras 中最好地实现这个想法?
您可以使用来自 Tensorflow 的 tf.split
使用 Keras Lambda
层
使用 Lambda 将形状为 (64,16,16)
的张量拆分为 (64,1,1,256)
,然后对您需要的任何索引进行子集化。
import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Lambda
# input data
data = np.ones((3,64,16,16))
# define lambda function to split
def lambda_fun(x) :
x = K.expand_dims(x, 4)
split1 = tf.split(x, 16, 2)
x = K.concatenate(split1, 4)
split2 = tf.split(x, 16, 3)
x = K.concatenate(split2, 4)
return x
## check thet splitting works fine
input = Input(shape= (64,16,16))
ll = Lambda(lambda_fun)(input)
model = Model(inputs=input, outputs=ll)
res = model.predict(data)
print(np.shape(res)) #(3, 64, 1, 1, 256)
我想在 Keras 中实现时空全卷积网络 (STFCN)。我需要提供 3D 卷积输出的每个深度列,例如形状为 (64, 16, 16)
的张量,作为单独 LSTM 的输入。
为了说明这一点,我有一个 (64 x 16 x 16)
张量,尺寸为 (channels, height, width)
。我需要将张量(显式或隐式)拆分为 16 * 16 = 256 个形状为 (64 x 1 x 1)
.
这是来自 STFCN 论文的图表,用于说明时空模块。我上面描述的是'Spatial Features'和'Spatio-Temporal Module'之间的箭头。
如何在 Keras 中最好地实现这个想法?
您可以使用来自 Tensorflow 的 tf.split
使用 Keras Lambda
层
使用 Lambda 将形状为 (64,16,16)
的张量拆分为 (64,1,1,256)
,然后对您需要的任何索引进行子集化。
import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Lambda
# input data
data = np.ones((3,64,16,16))
# define lambda function to split
def lambda_fun(x) :
x = K.expand_dims(x, 4)
split1 = tf.split(x, 16, 2)
x = K.concatenate(split1, 4)
split2 = tf.split(x, 16, 3)
x = K.concatenate(split2, 4)
return x
## check thet splitting works fine
input = Input(shape= (64,16,16))
ll = Lambda(lambda_fun)(input)
model = Model(inputs=input, outputs=ll)
res = model.predict(data)
print(np.shape(res)) #(3, 64, 1, 1, 256)