将一个冻结图中的节点值复制到另一个冻结图

Copy Frozen Values From A Frozen Graph to Another Frozen Graph

我有两个冻结图(frozen graph),它们分别训练并保存为不同的 .pb 文件,且共享一些相同的节点。如何把一张图中的节点值转移到另一张图?例如,如何复制 FakeQuantWithMinMaxVars 节点,用它替换另一张图中对应的节点?

我已经解决了这个问题,步骤如下:
1. 按名称匹配两张图中相似的节点,建立映射。
2. 用 tf.import_graph_def(配合 input_map)把两张图连接起来。
3. 用 graph_transforms 的 strip_unused_nodes 删除未使用的节点。

对于量化模型,不要使用 merge_duplicate_nodes 或 fold_batch_norms 变换。它们会丢失 min/max 量化信息,从而产生量化误差。代码如下:
import tensorflow as tf
import numpy as np

def load_graph(pb_file):
    """Load a binary frozen GraphDef file into a brand-new ``tf.Graph``.

    Args:
        pb_file: path to a binary-serialized ``GraphDef`` (``.pb``) file.

    Returns:
        A fresh ``tf.Graph`` holding every node of the file. The nodes are
        imported with ``name=''``, so they keep their original names instead
        of gaining an ``import/`` prefix.
    """
    # Read the raw protobuf bytes first; the file is closed before importing.
    with tf.gfile.GFile(pb_file, 'rb') as fid:
        raw_bytes = fid.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(raw_bytes)

    loaded = tf.Graph()
    # import_graph_def targets the default graph, so make `loaded` the default.
    with loaded.as_default():
        tf.import_graph_def(graph_def, name='')
    return loaded

resnet_pretrained = 'frozen_124.pb'
trained = 'frozen.pb'

# New file name to save the combined model.
# NOTE(review): not referenced by the merge/transform code that follows;
# the optimized graph is written as 'optimized_model.pb' instead.
final_graph = 'final_graph.pb'

# Load both frozen graphs.
graph1 = load_graph(resnet_pretrained)
graph2 = load_graph(trained)

# Build a mapping {graph2 tensor name -> graph1 tensor name}. It is later
# passed as input_map when graph2 is imported, so each listed graph2 tensor
# is replaced by the matching graph1 tensor:
#   * an op present in both graphs under the same name maps to itself;
#   * otherwise a graph2 op whose name contains "resnet" maps to the graph1
#     op named with "resnet" replaced by "model", if such an op exists.
# Only output 0 (":0") of each op is mapped.
# A dict keyed by op name gives O(1) membership tests; scanning a list
# for every graph2 op was accidentally quadratic.
ops1_by_name = {op.name: op for op in graph1.get_operations()}
replace_dict = {}
for op in graph2.get_operations():
    if op.name in ops1_by_name:
        target_name = op.name
    elif 'resnet' in op.name:
        target_name = op.name.replace("resnet", "model")
        if target_name not in ops1_by_name:
            continue
    else:
        continue
    # Ops without outputs (e.g. NoOp) have no ":0" tensor. Mapping them would
    # make import_graph_def reject the return_elements / input_map entry.
    if not op.outputs or not ops1_by_name[target_name].outputs:
        continue
    replace_dict[op.name + ':0'] = target_name + ':0'

# Merge the two graphs into a new graph `final`:
#   1. import graph1 (default scope "import") and fetch the graph1 tensors
#      named by replace_dict's values;
#   2. import graph2 (next default scope, "import_1") with input_map wiring
#      those graph1 tensors in place of the graph2 tensors named by the keys.
with tf.Graph().as_default() as final:
    y = tf.import_graph_def(
        graph1.as_graph_def(),
        return_elements=replace_dict.values(),
    )
    # keys() and values() iterate in the same order for an unmodified dict,
    # so zipping the keys with the returned tensors pairs them correctly.
    new = dict(zip(replace_dict.keys(), y))
    z = tf.import_graph_def(
        graph2.as_graph_def(),
        input_map=new,
        return_elements=["concatenate_1/concat:0"],
    )

# Clean up the merged graph: drop Identity nodes, then strip every node not
# needed to compute the outputs from the inputs. Node names carry the scopes
# from the merge step ("import/" for graph1, "import_1/" for graph2).
from tensorflow.tools.graph_transforms import TransformGraph

input_names = ["import/input_image", "import_1/input_box"]
output_names = ["import_1/concatenate_1/concat"]
transforms = [
    'remove_nodes(op=Identity)',
    'strip_unused_nodes',
]

output_graph_def = TransformGraph(
    final.as_graph_def(), input_names, output_names, transforms)

tf.train.write_graph(
    output_graph_def, '.', name='optimized_model.pb', as_text=False)
print('Graph optimized!')