Tensorflow 2.0: How to implement a network with fusion at feature level?

I am trying to implement a small model in TensorFlow for a prediction task, where two signals are taken as input, pass through several layers separately, and are then combined in the later layers to produce the output prediction. Essentially, the model works like this:

(Signal A) -> [L 1] -> [L 2] -> ... -> [L k] 
                                            \
                                             \
                                               -> [L k+1] ->...-> [Final Layer] -> Output
                                             /
                                            /
(Signal B) -> [L 1] -> [L 2] -> ... -> [L k]

where the [L i] are the different layers of the network. Before the fusion point, the first part of the network is identical for both signals. What is the right way to implement this model in TensorFlow 2.0? I believe Sequential is not an option in this case, but can I do it with the Functional API, or should I do it via model subclassing? From what I have read, there does not seem to be much difference between the two approaches.
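For a static two-branch topology like this, either approach works and produces the same model; subclassing mainly pays off when you need custom control flow in `call`. A minimal subclassed sketch of the same shape (layer sizes here are illustrative, not taken from the answer below):

```python
import tensorflow as tf
from tensorflow.keras import layers

class FusionNet(tf.keras.Model):
    """Two branches with shared pre-fusion layers, merged at feature level."""
    def __init__(self):
        super().__init__()
        # shared encoder: the pre-fusion layers [L 1]..[L k], applied to both signals
        self.encoder = tf.keras.Sequential([
            layers.Conv1D(32, 3, activation='relu'),
            layers.GlobalMaxPooling1D(),
            layers.Dense(16, activation='relu'),
        ])
        self.concat = layers.Concatenate()
        # post-fusion layers [L k+1]..[Final Layer]
        self.head = layers.Dense(1, activation='sigmoid')

    def call(self, inputs):
        sig_a, sig_b = inputs
        fused = self.concat([self.encoder(sig_a), self.encoder(sig_b)])
        return self.head(fused)
```

Calling `FusionNet()([a, b])` with two `(batch, 256, 1)` tensors returns a `(batch, 1)` output. Note that reusing one `self.encoder` ties the weights of the two branches; if the branches should learn separate weights, build two encoders instead.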

Here is a template for the model with the Functional API; you can change the layers as per your needs.

Your base model (common to both signals) -

from tensorflow.keras.layers import Input, Conv1D, Concatenate, MaxPooling1D, Flatten, Dense, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam

# baseline model

input_shape = (256, 1) # assuming your signals have length 256, 1 channel

# conv base model
sig_input = Input(input_shape)
cnn1 = Conv1D(64, 3, activation='relu', kernel_regularizer=l2(2e-4))(sig_input)
mp1 = MaxPooling1D()(cnn1)
mp1 = BatchNormalization()(mp1)
cnn2 = Conv1D(128, 3, activation='relu', kernel_regularizer=l2(2e-4))(mp1)
mp2 = MaxPooling1D()(cnn2)
mp2 = BatchNormalization()(mp2)
cnn3 = Conv1D(128, 3, activation='relu', kernel_regularizer=l2(2e-4))(mp2)
mp3 = MaxPooling1D()(cnn3)
mp3 = BatchNormalization()(mp3)
cnn4 = Conv1D(256, 3, activation='relu', kernel_regularizer=l2(2e-4))(mp3)
mp4 = MaxPooling1D()(cnn4)
mp4 = BatchNormalization()(mp4)
flat = Flatten()(mp4)
embed = Dense(64, activation='sigmoid')(flat)

conv_base = Model(sig_input, embed)

conv_base.summary()

Network summary:

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_6 (InputLayer)         [(None, 256, 1)]          0         
_________________________________________________________________
conv1d_12 (Conv1D)           (None, 254, 64)           256       
_________________________________________________________________
max_pooling1d_12 (MaxPooling (None, 127, 64)           0         
_________________________________________________________________
batch_normalization_12 (Batc (None, 127, 64)           256       
_________________________________________________________________
conv1d_13 (Conv1D)           (None, 125, 128)          24704     
_________________________________________________________________
max_pooling1d_13 (MaxPooling (None, 62, 128)           0         
_________________________________________________________________
batch_normalization_13 (Batc (None, 62, 128)           512       
_________________________________________________________________
conv1d_14 (Conv1D)           (None, 60, 128)           49280     
_________________________________________________________________
max_pooling1d_14 (MaxPooling (None, 30, 128)           0         
_________________________________________________________________
batch_normalization_14 (Batc (None, 30, 128)           512       
_________________________________________________________________
conv1d_15 (Conv1D)           (None, 28, 256)           98560     
_________________________________________________________________
max_pooling1d_15 (MaxPooling (None, 14, 256)           0         
_________________________________________________________________
batch_normalization_15 (Batc (None, 14, 256)           1024      
_________________________________________________________________
flatten_3 (Flatten)          (None, 3584)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 64)                229440    
=================================================================
Total params: 404,544
Trainable params: 403,392
Non-trainable params: 1,152
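The parameter counts in the summary follow directly from the layer shapes; a quick sanity check of the arithmetic (these helper functions are illustrative, not part of Keras):

```python
def conv1d_params(kernel_size, in_channels, filters):
    # weights: kernel_size * in_channels * filters, plus one bias per filter
    return kernel_size * in_channels * filters + filters

def bn_params(channels):
    # gamma, beta, moving mean, moving variance: 4 values per channel
    return 4 * channels

def dense_params(in_features, units):
    return in_features * units + units

print(conv1d_params(3, 1, 64))     # → 256 (conv1d_12)
print(conv1d_params(3, 64, 128))   # → 24704 (conv1d_13)
print(bn_params(64))               # → 256 (batch_normalization_12)
print(dense_params(3584, 64))      # → 229440 (dense_3)
```

Of the 4 * channels values in each BatchNormalization layer, only gamma and beta are trainable, which is where the 1,152 non-trainable parameters come from.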

Second, the fusion network -

left_input = Input(input_shape)
right_input = Input(input_shape)

# encode each of the two inputs into a vector with the base conv model
encoded_l = conv_base(left_input)
encoded_r = conv_base(right_input)



fusion = Concatenate()([encoded_l,encoded_r]) # this can be any other fusion method too

prediction = Dense(1, activation='sigmoid')(fusion)

twin_net = Model([left_input,right_input],prediction)

optimizer = Adam(0.001)

twin_net.compile(loss="binary_crossentropy",optimizer=optimizer)

twin_net.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_7 (InputLayer)            [(None, 256, 1)]     0                                            
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 256, 1)]     0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 64)           404544      input_7[0][0]                    
                                                                 input_8[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 128)          0           model_2[1][0]                    
                                                                 model_2[2][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            129         concatenate[0][0]                
==================================================================================================
Total params: 404,673
Trainable params: 403,521
Non-trainable params: 1,152
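As the comment on the Concatenate line says, the fusion step can be any merge operation. A self-contained sketch using Subtract instead of Concatenate, with a smaller stand-in encoder and hypothetical random data just to show the two-input call shape:

```python
import numpy as np
from tensorflow.keras.layers import Input, Conv1D, GlobalMaxPooling1D, Dense, Subtract
from tensorflow.keras.models import Model

input_shape = (256, 1)

# small stand-in for conv_base above
sig_input = Input(input_shape)
x = Conv1D(16, 3, activation='relu')(sig_input)
x = GlobalMaxPooling1D()(x)
encoder = Model(sig_input, Dense(8, activation='relu')(x))

left_input = Input(input_shape)
right_input = Input(input_shape)

# element-wise difference of the two embeddings instead of concatenation
diff = Subtract()([encoder(left_input), encoder(right_input)])
prediction = Dense(1, activation='sigmoid')(diff)

net = Model([left_input, right_input], prediction)
net.compile(loss='binary_crossentropy', optimizer='adam')

# one epoch on hypothetical dummy data: inputs are passed as a list of two arrays
x_a = np.random.rand(32, 256, 1).astype('float32')
x_b = np.random.rand(32, 256, 1).astype('float32')
y = np.random.randint(0, 2, size=(32, 1)).astype('float32')
net.fit([x_a, x_b], y, batch_size=8, epochs=1, verbose=0)
```

Training and prediction with `twin_net` work the same way: pass `[signal_a_batch, signal_b_batch]` wherever a single input array would normally go.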