Retrain Frozen Graph in Tensorflow 2.x

I have managed to retrain a frozen graph in tensorflow 1 based on this wonderfully detailed topic. Basically, the method goes as follows:

  1. Load the frozen model
  2. Replace each constant frozen node with a variable node
  3. The newly substituted variable nodes are redirected to the corresponding outputs of the frozen nodes.

This works in tensorflow 1.x, verified by checking tf.compat.v1.trainable_variables. However, in tensorflow 2.x it no longer works.
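The check itself is simple (a minimal sketch; detection_training_graph refers to the retrained graph built in step 3 below):

# In TF 1.x this prints a non-empty list after step 3;
# in TF 2.x the same check comes back empty.
with detection_training_graph.as_default():
    print(tf.compat.v1.trainable_variables())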

Here are the code snippets:

1/ Load the frozen model

import tensorflow as tf

frozen_path = '...'
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.graph_util.import_graph_def(od_graph_def, name='')
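To sanity-check the import (not part of the original snippet), the frozen graph's operations should now be visible on detection_graph:

# Quick check that the GraphDef was imported: list a few op names.
with detection_graph.as_default():
    ops = detection_graph.get_operations()
    print(len(ops), [op.name for op in ops[:5]])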

2/ Create the clones

with detection_graph.as_default():
    const_var_name_pairs = {}
    probable_variables = [op for op in detection_graph.get_operations() if op.type == "Const"]
    available_names = [op.name for op in detection_graph.get_operations()]
    for op in probable_variables:
        name = op.name
        if name+'/read' not in available_names:
            continue
        tensor = detection_graph.get_tensor_by_name('{}:0'.format(name))
        with tf.compat.v1.Session() as s:
            tensor_as_numpy_array = s.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        var = tf.Variable(name=var_name, dtype=op.outputs[0].dtype,
                          initial_value=tensor_as_numpy_array,
                          trainable=True, shape=var_shape)
        const_var_name_pairs[name] = var_name
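As a side note, opening a new Session for every constant can be slow on large graphs; the same values could be fetched in one pass (an optional optimization sketch, not part of the original flow):

# Optional: evaluate all candidate constants in a single Session.run
# instead of creating one Session per tensor.
with detection_graph.as_default():
    const_tensors = {
        op.name: detection_graph.get_tensor_by_name('{}:0'.format(op.name))
        for op in probable_variables
        if op.name + '/read' in available_names
    }
    with tf.compat.v1.Session() as s:
        const_values = s.run(const_tensors)  # dict: name -> numpy array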

3/ Replace the frozen nodes via the graph editor

import graph_def_editor as ge
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = dict([(n.name, n) for n in ge_graph.nodes])
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name+'/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    writer.close()

The problem lies in my graph editor step: I import the tf.GraphDef instead of the original tf.Graph that holds the variables.

Quick fix by correcting step 3

Sol1: using the Graph Editor

ge_graph = ge.Graph(detection_graph)
for const_name, var_name in const_var_name_pairs.items():
    const_op = ge_graph._node_name_to_node[const_name+'/read']
    var_reader_op = ge_graph._node_name_to_node[var_name+'/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

However, this requires eager execution to be disabled. To work around the eager execution issue, you should attach a MetaGraphDef to the Graph Editor as follows:

with detection_graph.as_default():
    meta_saver = tf.compat.v1.train.Saver()
    meta = meta_saver.export_meta_graph()
ge_graph = ge.Graph(detection_graph, collections=ge.graph._extract_collection_defs(meta))
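For reference, the eager-execution constraint mentioned above would otherwise have to be handled by disabling eager execution globally before any other TF call (a minimal sketch of the alternative that the MetaGraphDef trick avoids):

import tensorflow as tf

# Must be called before any other TensorFlow operation in the process;
# the MetaGraphDef approach above makes this unnecessary.
tf.compat.v1.disable_eager_execution()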

However, the trickiest part of making the model trainable in tf 2.x is that we should export the graph ourselves instead of letting the Graph Editor export it directly. The reason is that the Graph Editor turns the variable dtypes into resources. Therefore, we should export the graph as a GraphDef and import the VariableDefs into the new graph:

from tensorflow.core.framework import variable_pb2

test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        ret = variable_pb2.VariableDef()
        ret.variable_name = var._variable_name
        ret.initial_value_name = var._initial_value_name
        ret.initializer_name = var._initializer_name
        ret.snapshot_name = var._snapshot_name
        ret.trainable = var._trainable
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret,dtype=tf.float32)
        test_graph.add_to_collections(var.collection_names, tf_var)
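With the VariableDefs imported, the collections can be checked the same way as in TF 1.x (a quick sanity check, assuming the collections attached above include trainable_variables):

# The re-created variables should now be visible to the TF1-style APIs.
with test_graph.as_default():
    train_vars = tf.compat.v1.trainable_variables()
    assert len(train_vars) > 0, "no trainable variables were attached"
    print(len(train_vars), train_vars[0].name)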

Sol2: manual remapping via GraphDef

with detection_graph.as_default() as graph:
    # remap_input_node is a user-defined helper (implementation not shown) that rewires
    # the GraphDef inputs from the original constants to the new variables.
    training_graph_def = remap_input_node(detection_graph.as_graph_def(), const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"


detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        ret = variable_pb2.VariableDef()
        ret.variable_name = var.name
        ret.initial_value_name = var.name[:-2] + '/Initializer/initial_value:0'
        ret.initializer_name = var.name[:-2] + '/Assign'
        ret.snapshot_name = var.name[:-2] + '/Read/ReadVariableOp:0'
        ret.trainable = True
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        detection_training_graph.add_to_collections({'trainable_variables', 'variables'}, tf_var)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"
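Once either solution produces a graph with a populated trainable_variables collection, a TF1-style training loop can be attached to it. The sketch below is only an illustration: the tensor names input:0, groundtruth:0 and total_loss:0 are hypothetical placeholders and must be replaced with the actual names from your frozen model (use test_graph instead of detection_training_graph for Sol1):

# Hypothetical training loop; the tensor names and the data feeding
# below are placeholders, not names from the original frozen graph.
with detection_training_graph.as_default():
    inputs = detection_training_graph.get_tensor_by_name('input:0')
    labels = detection_training_graph.get_tensor_by_name('groundtruth:0')
    loss = detection_training_graph.get_tensor_by_name('total_loss:0')

    optimizer = tf.compat.v1.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(loss, var_list=tf.compat.v1.trainable_variables())

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        for step in range(num_steps):                      # num_steps: your choice
            batch_images, batch_labels = next(data_iter)   # data_iter: your own input pipeline
            _, loss_val = sess.run([train_op, loss],
                                   feed_dict={inputs: batch_images, labels: batch_labels})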