检查tensorflow keras模型中的下一层

Check which are the next layers in a tensorflow keras model

我有一个 模型,它在层之间有快捷方式。对于每一层,我想获得下一个连接层的名称(或索引),因为简单地遍历所有 model.layers 不会告诉我该层是否连接到前一个层。

示例模型可以是:

model = tf.keras.applications.resnet50.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)

您可以查看整个模型及其与 keras 的连接 Model plotting utilities

tf.keras.utils.plot_model(model, to_file='path/to/image', show_shapes=True)

这样可以提取dict格式的信息...

首先,定义一个效用函数并从每个Functional模型(code reference

中获取model.summary()方法中的相关节点
relevant_nodes = []
for v in model._nodes_by_depth.values():
    relevant_nodes += v

def get_layer_summary_with_connections(layer):
    
    info = {}
    connections = []
    for node in layer._inbound_nodes:
        if relevant_nodes and node not in relevant_nodes:
            # node is not part of the current network
            continue

        for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
            connections.append(inbound_layer.name)
            
    name = layer.name
    info['type'] = layer.__class__.__name__
    info['parents'] = connections
            
    return info

其次,层层迭代提取信息:

results = {}
layers = model.layers
for layer in layers:
    info = get_layer_summary_with_connections(layer)
    results[layer.name] = info

results 是嵌套的 dict,格式如下:

{
  'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'},
  ...
  'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'}
}

对于 ResNet50 它导致:

{
  'input_4': {'type': 'InputLayer', 'parents': []},
  'conv1_pad': {'type': 'ZeroPadding2D', 'parents': ['input_4']},
  'conv1_conv': {'type': 'Conv2D', 'parents': ['conv1_pad']},
  'conv1_bn': {'type': 'BatchNormalization', 'parents': ['conv1_conv']},
  ...
  'conv5_block3_out': {'type': 'Activation', 'parents': ['conv5_block3_add']},
  'avg_pool': {'type': 'GlobalAveragePooling2D', 'parents' ['conv5_block3_out']},
  'predictions': {'type': 'Dense', 'parents': ['avg_pool']}
}

另外,您可以将get_layer_summary_with_connections修改为return所有您感兴趣的信息