检查tensorflow keras模型中的下一层
Check which are the next layers in a tensorflow keras model
我有一个 keras 模型,它在层之间有快捷方式。对于每一层,我想获得下一个连接层的名称(或索引),因为简单地遍历所有 model.layers
不会告诉我该层是否连接到前一个层。
示例模型可以是:
model = tf.keras.applications.resnet50.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)
您可以查看整个模型及其与 keras 的连接 Model plotting utilities
tf.keras.utils.plot_model(model, to_file='path/to/image', show_shapes=True)
这样可以提取dict
格式的信息...
首先,定义一个效用函数并从每个Functional
模型(code reference)
中获取model.summary()
方法中的相关节点
relevant_nodes = []
for v in model._nodes_by_depth.values():
relevant_nodes += v
def get_layer_summary_with_connections(layer):
info = {}
connections = []
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
connections.append(inbound_layer.name)
name = layer.name
info['type'] = layer.__class__.__name__
info['parents'] = connections
return info
其次,层层迭代提取信息:
results = {}
layers = model.layers
for layer in layers:
info = get_layer_summary_with_connections(layer)
results[layer.name] = info
results
是嵌套的 dict
,格式如下:
{
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'},
...
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'}
}
对于 ResNet50
它导致:
{
'input_4': {'type': 'InputLayer', 'parents': []},
'conv1_pad': {'type': 'ZeroPadding2D', 'parents': ['input_4']},
'conv1_conv': {'type': 'Conv2D', 'parents': ['conv1_pad']},
'conv1_bn': {'type': 'BatchNormalization', 'parents': ['conv1_conv']},
...
'conv5_block3_out': {'type': 'Activation', 'parents': ['conv5_block3_add']},
'avg_pool': {'type': 'GlobalAveragePooling2D', 'parents' ['conv5_block3_out']},
'predictions': {'type': 'Dense', 'parents': ['avg_pool']}
}
另外,您可以将get_layer_summary_with_connections
修改为return所有您感兴趣的信息
我有一个 keras 模型,它在层之间有快捷方式。对于每一层,我想获得下一个连接层的名称(或索引),因为简单地遍历所有 model.layers
不会告诉我该层是否连接到前一个层。
示例模型可以是:
model = tf.keras.applications.resnet50.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)
您可以查看整个模型及其与 keras 的连接 Model plotting utilities
tf.keras.utils.plot_model(model, to_file='path/to/image', show_shapes=True)
这样可以提取dict
格式的信息...
首先,定义一个效用函数并从每个Functional
模型(code reference)
model.summary()
方法中的相关节点
relevant_nodes = []
for v in model._nodes_by_depth.values():
relevant_nodes += v
def get_layer_summary_with_connections(layer):
info = {}
connections = []
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
connections.append(inbound_layer.name)
name = layer.name
info['type'] = layer.__class__.__name__
info['parents'] = connections
return info
其次,层层迭代提取信息:
results = {}
layers = model.layers
for layer in layers:
info = get_layer_summary_with_connections(layer)
results[layer.name] = info
results
是嵌套的 dict
,格式如下:
{
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'},
...
'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'}
}
对于 ResNet50
它导致:
{
'input_4': {'type': 'InputLayer', 'parents': []},
'conv1_pad': {'type': 'ZeroPadding2D', 'parents': ['input_4']},
'conv1_conv': {'type': 'Conv2D', 'parents': ['conv1_pad']},
'conv1_bn': {'type': 'BatchNormalization', 'parents': ['conv1_conv']},
...
'conv5_block3_out': {'type': 'Activation', 'parents': ['conv5_block3_add']},
'avg_pool': {'type': 'GlobalAveragePooling2D', 'parents' ['conv5_block3_out']},
'predictions': {'type': 'Dense', 'parents': ['avg_pool']}
}
另外,您可以将get_layer_summary_with_connections
修改为return所有您感兴趣的信息