用于堆叠两个 CNN 的适配模块设计

Adaptation module design for stacking two CNNs

我正在尝试堆叠两个不同的CNN使用适应模块桥接它们,但我很难正确确定适配模块的层超参数。

更准确地说,我想训练自适应模块来桥接两个卷积层:

  1. A 层输出形状:(29,29,256)
  2. B 层输入形状:(8,8,384)

所以,在Layer A之后,我依次添加适配模块,我选择的是:

最后,我尝试将 B 层添加到模型中,但我从 tensorflow 得到以下 错误

InvalidArgumentError: Dimensions must be equal, but are 384 and 288 for '{{node batch_normalization_159/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization_159/scale, batch_normalization_159/ReadVariableOp, batch_normalization_159/FusedBatchNormV3/ReadVariableOp, batch_normalization_159/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [?,8,8,384], [288], [288], [288], [288].

有一个最小的可重现示例:

from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.mobilenet import MobileNet
from keras.layers import Conv2D, MaxPool2D
from keras.models import Sequential

mobile_model = MobileNet(weights='imagenet')
server_model = InceptionResNetV2(weights='imagenet')

hybrid = Sequential()

for i, layer in enumerate(mobile_model.layers):
  if i <= 36:
    layer.trainable = False
    hybrid.add(layer)

hybrid.add(Conv2D(384, kernel_size=(3,3), padding='same'))
hybrid.add(MaxPool2D(pool_size=(2,2), strides=(4,4), padding='same'))

for i, layer in enumerate(server_model.layers):
  if i >= 610:
    layer.trainable = False
    hybrid.add(layer)

顺序模型只支持层像链表一样排列的模型——每一层只接受一层的输出,每一层的输出只提供给一个层。你的两个基础模型都有残差块,打破了上面的假设,把模型架构变成了有向无环图(DAG)。

要完成您想做的事情,您需要使用函数 API。使用 Functional API,您可以显式控制中间激活,也就是 KerasTensors。

对于第一个模型,您可以跳过这些额外的工作,只需像这样从现有图的子集创建一个新模型

sub_mobile = keras.models.Model(mobile_model.inputs, mobile_model.layers[36].output)

为第二个模型的某些层布线要困难得多。切掉 keras 模型的末尾很容易 - 由于需要 tf.keras.Input 占位符,因此切开开头要困难得多。要成功做到这一点,您需要编写一个遍历各层的模型行走算法,跟踪输出 KerasTensors,然后使用新输入调用每个层以创建新的输出 KerasTensor。

您可以通过简单地找到 InceptionResNet 的一些源代码并通过 Python 添加层而不是内省现有模型来避免所有这些工作。这是一个可能符合要求的。

https://github.com/yuyang-huang/keras-inception-resnet-v2/blob/master/inception_resnet_v2.py