在 TensorFlow 2.x 中重新训练冻结图
Retrain Frozen Graph in Tensorflow 2.x
我已经参照 this wonderful detail topic,在 TensorFlow 1 中实现了冻结图的重新训练。基本方法如下:
- 加载冻结模型
- 将 `constant frozen node`(常量冻结节点)替换为 `variable node`(变量节点)。
- 新替换的变量节点将被重定向到冻结节点的相应输出。
通过检查 `tf.compat.v1.trainable_variables` 可以确认,这在 TensorFlow 1.x 中是有效的。但在 TensorFlow 2.x 中就不行了。
下面是代码片段:
1/加载冻结模型
# 1/ Load the frozen model: parse the serialized GraphDef from disk and
# import it into a fresh tf.Graph. An empty name prefix keeps the original
# node names, which the later steps look up by string.
frozen_path = '...'  # placeholder: path to the frozen .pb file
detection_graph = tf.Graph()
with detection_graph.as_default():
    graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.graph_util.import_graph_def(graph_def, name='')
2/ 创建克隆
# 2/ For every Const op that has a matching "<name>/read" op (the frozen
# weights), create a trainable tf.Variable holding the same value, shape and
# dtype, and remember which variable replaces which constant.
with detection_graph.as_default():
    # Maps frozen constant name -> name of the variable that will replace it.
    const_var_name_pairs = {}
    # A set makes each "<name>/read" membership test O(1) instead of a
    # linear scan over every op name.
    available_names = {op.name for op in detection_graph.get_operations()}
    frozen_consts = [
        op for op in detection_graph.get_operations()
        if op.type == "Const" and op.name + '/read' in available_names
    ]
    const_tensors = [op.outputs[0] for op in frozen_consts]
    # Evaluate all constants in one session instead of opening a new
    # session per constant.
    with tf.compat.v1.Session() as s:
        const_values = s.run(const_tensors)
    for op, tensor, value in zip(frozen_consts, const_tensors, const_values):
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(op.name)
        tf.Variable(name=var_name, dtype=tensor.dtype, initial_value=value,
                    trainable=True, shape=tensor.shape)
        const_var_name_pairs[op.name] = var_name
3/ 通过图形编辑器替换冻结节点
# 3/ Rewire the graph with graph_def_editor: the outputs of every frozen
# "<const>/read" node are swapped with those of the matching variable's
# "<var>/Read/ReadVariableOp" node.
# NOTE: building ge.Graph from a GraphDef (as here) is what loses the
# variables; the fix below builds it from the tf.Graph itself instead.
import graph_def_editor as ge
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = {n.name: n for n in ge_graph.nodes}
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name + '/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    # close() must actually be called; a bare `writer.close` only references
    # the method, so the event file is never flushed/closed.
    writer.close()
问题出在我对 Graph Editor 的用法上:我导入的是 GraphDef(`detection_graph.as_graph_def()`),而不是包含变量的原始 `tf.Graph`。
通过修正步骤 3 快速解决
Sol1:使用Graph Editor
# Sol1 (fixed step 3): build the graph_def_editor Graph from the original
# tf.Graph, which still carries the variables, rather than from its GraphDef.
# Then redirect each frozen "<const>/read" node to the variable's read op.
ge_graph = ge.Graph(detection_graph)
for frozen_name, replacement_name in const_var_name_pairs.items():
    # _node_name_to_node is a private graph_def_editor lookup table.
    frozen_read = ge_graph._node_name_to_node[frozen_name + '/read']
    variable_read = ge_graph._node_name_to_node[
        replacement_name + '/Read/ReadVariableOp'
    ]
    ge.swap_outputs(ge.sgv(frozen_read), ge.sgv(variable_read))
但是,这需要禁用急切执行(eager execution)。要解决急切执行带来的问题,应将 `MetaGraphDef` 附加到 Graph Editor,如下所示:
# Export a MetaGraphDef and hand its collections to graph_def_editor, so the
# editor also works with eager execution enabled.
with detection_graph.as_default():
    meta_graph_def = tf.compat.v1.train.Saver().export_meta_graph()
# _extract_collection_defs is a private graph_def_editor helper.
collection_defs = ge.graph._extract_collection_defs(meta_graph_def)
ge_graph = ge.Graph(detection_graph, collections=collection_defs)
然而,这是在 TF 2.x 中让模型可训练最棘手的地方:我们应该自己导出图,而不是直接用 Graph Editor 导出。原因是 Graph Editor 会把变量的数据类型变成 `resource`。因此,应先将图导出为 GraphDef,再把变量定义(`VariableDef`)导入到图中:
# Export the edited graph ourselves: import graph_def_editor's GraphDef into
# a fresh tf.Graph, then re-create each variable from a VariableDef and put
# it back into the collections graph_def_editor recorded for it.
from tensorflow.core.framework import variable_pb2  # used below; was never imported

test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        # Copy the node names graph_def_editor tracked (private attributes).
        ret = variable_pb2.VariableDef()
        ret.variable_name = var._variable_name
        ret.initial_value_name = var._initial_value_name
        ret.initializer_name = var._initializer_name
        ret.snapshot_name = var._snapshot_name
        ret.trainable = var._trainable
        # graph_def_editor turns the variables into resource variables.
        ret.is_resource = True
        # variable_def is mutually exclusive with the other tf.Variable
        # arguments, so no hard-coded dtype is passed.
        tf_var = tf.Variable(variable_def=ret)
        test_graph.add_to_collections(var.collection_names, tf_var)
Sol2:通过 Graphdef 手动映射
# Sol2: remap the GraphDef by hand, import it into a new graph, then
# re-register the trainable variables there via their VariableDefs.
with detection_graph.as_default():
    # NOTE(review): remap_input_node is not defined in this snippet; presumably
    # it rewires consumers of each frozen constant to its variable - confirm.
    training_graph_def = remap_input_node(detection_graph.as_graph_def(),
                                          const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    # Explicit raise instead of assert, which is stripped under `python -O`.
    if not current_var:
        raise RuntimeError("no training variables")

detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        # to_proto() records the variable's real node names instead of
        # rebuilding them by slicing var.name[:-2] and appending
        # '/Initializer/initial_value:0', '/Assign', '/Read/ReadVariableOp:0'.
        # variable_def is mutually exclusive with the other tf.Variable
        # arguments, so no dtype is passed.
        tf_var = tf.Variable(variable_def=var.to_proto())
        detection_training_graph.add_to_collections(
            [tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
             tf.compat.v1.GraphKeys.GLOBAL_VARIABLES],
            tf_var,
        )
    current_var = tf.compat.v1.trainable_variables()
    if not current_var:
        raise RuntimeError("no training variables")
我已经参照 this wonderful detail topic,在 TensorFlow 1 中实现了冻结图的重新训练。基本方法如下:
- 加载冻结模型
- 将 `constant frozen node`(常量冻结节点)替换为 `variable node`(变量节点)。
- 新替换的变量节点将被重定向到冻结节点的相应输出。
通过检查 `tf.compat.v1.trainable_variables` 可以确认,这在 TensorFlow 1.x 中是有效的。但在 TensorFlow 2.x 中就不行了。
下面是代码片段:
1/加载冻结模型
# 1/ Load the frozen model: parse the serialized GraphDef from disk and
# import it into a fresh tf.Graph. An empty name prefix keeps the original
# node names, which the later steps look up by string.
frozen_path = '...'  # placeholder: path to the frozen .pb file
detection_graph = tf.Graph()
with detection_graph.as_default():
    graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.graph_util.import_graph_def(graph_def, name='')
2/ 创建克隆
# 2/ For every Const op that has a matching "<name>/read" op (the frozen
# weights), create a trainable tf.Variable holding the same value, shape and
# dtype, and remember which variable replaces which constant.
with detection_graph.as_default():
    # Maps frozen constant name -> name of the variable that will replace it.
    const_var_name_pairs = {}
    # A set makes each "<name>/read" membership test O(1) instead of a
    # linear scan over every op name.
    available_names = {op.name for op in detection_graph.get_operations()}
    frozen_consts = [
        op for op in detection_graph.get_operations()
        if op.type == "Const" and op.name + '/read' in available_names
    ]
    const_tensors = [op.outputs[0] for op in frozen_consts]
    # Evaluate all constants in one session instead of opening a new
    # session per constant.
    with tf.compat.v1.Session() as s:
        const_values = s.run(const_tensors)
    for op, tensor, value in zip(frozen_consts, const_tensors, const_values):
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(op.name)
        tf.Variable(name=var_name, dtype=tensor.dtype, initial_value=value,
                    trainable=True, shape=tensor.shape)
        const_var_name_pairs[op.name] = var_name
3/ 通过图形编辑器替换冻结节点
# 3/ Rewire the graph with graph_def_editor: the outputs of every frozen
# "<const>/read" node are swapped with those of the matching variable's
# "<var>/Read/ReadVariableOp" node.
# NOTE: building ge.Graph from a GraphDef (as here) is what loses the
# variables; the fix below builds it from the tf.Graph itself instead.
import graph_def_editor as ge
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = {n.name: n for n in ge_graph.nodes}
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name + '/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    # close() must actually be called; a bare `writer.close` only references
    # the method, so the event file is never flushed/closed.
    writer.close()
问题出在我对 Graph Editor 的用法上:我导入的是 GraphDef(`detection_graph.as_graph_def()`),而不是包含变量的原始 `tf.Graph`。
通过修正步骤 3 快速解决
Sol1:使用Graph Editor
# Sol1 (fixed step 3): build the graph_def_editor Graph from the original
# tf.Graph, which still carries the variables, rather than from its GraphDef.
# Then redirect each frozen "<const>/read" node to the variable's read op.
ge_graph = ge.Graph(detection_graph)
for frozen_name, replacement_name in const_var_name_pairs.items():
    # _node_name_to_node is a private graph_def_editor lookup table.
    frozen_read = ge_graph._node_name_to_node[frozen_name + '/read']
    variable_read = ge_graph._node_name_to_node[
        replacement_name + '/Read/ReadVariableOp'
    ]
    ge.swap_outputs(ge.sgv(frozen_read), ge.sgv(variable_read))
但是,这需要禁用急切执行(eager execution)。要解决急切执行带来的问题,应将 `MetaGraphDef` 附加到 Graph Editor,如下所示:
# Export a MetaGraphDef and hand its collections to graph_def_editor, so the
# editor also works with eager execution enabled.
with detection_graph.as_default():
    meta_graph_def = tf.compat.v1.train.Saver().export_meta_graph()
# _extract_collection_defs is a private graph_def_editor helper.
collection_defs = ge.graph._extract_collection_defs(meta_graph_def)
ge_graph = ge.Graph(detection_graph, collections=collection_defs)
然而,这是在 TF 2.x 中让模型可训练最棘手的地方:我们应该自己导出图,而不是直接用 Graph Editor 导出。原因是 Graph Editor 会把变量的数据类型变成 `resource`。因此,应先将图导出为 GraphDef,再把变量定义(`VariableDef`)导入到图中:
# Export the edited graph ourselves: import graph_def_editor's GraphDef into
# a fresh tf.Graph, then re-create each variable from a VariableDef and put
# it back into the collections graph_def_editor recorded for it.
from tensorflow.core.framework import variable_pb2  # used below; was never imported

test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        # Copy the node names graph_def_editor tracked (private attributes).
        ret = variable_pb2.VariableDef()
        ret.variable_name = var._variable_name
        ret.initial_value_name = var._initial_value_name
        ret.initializer_name = var._initializer_name
        ret.snapshot_name = var._snapshot_name
        ret.trainable = var._trainable
        # graph_def_editor turns the variables into resource variables.
        ret.is_resource = True
        # variable_def is mutually exclusive with the other tf.Variable
        # arguments, so no hard-coded dtype is passed.
        tf_var = tf.Variable(variable_def=ret)
        test_graph.add_to_collections(var.collection_names, tf_var)
Sol2:通过 Graphdef 手动映射
# Sol2: remap the GraphDef by hand, import it into a new graph, then
# re-register the trainable variables there via their VariableDefs.
with detection_graph.as_default():
    # NOTE(review): remap_input_node is not defined in this snippet; presumably
    # it rewires consumers of each frozen constant to its variable - confirm.
    training_graph_def = remap_input_node(detection_graph.as_graph_def(),
                                          const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    # Explicit raise instead of assert, which is stripped under `python -O`.
    if not current_var:
        raise RuntimeError("no training variables")

detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        # to_proto() records the variable's real node names instead of
        # rebuilding them by slicing var.name[:-2] and appending
        # '/Initializer/initial_value:0', '/Assign', '/Read/ReadVariableOp:0'.
        # variable_def is mutually exclusive with the other tf.Variable
        # arguments, so no dtype is passed.
        tf_var = tf.Variable(variable_def=var.to_proto())
        detection_training_graph.add_to_collections(
            [tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
             tf.compat.v1.GraphKeys.GLOBAL_VARIABLES],
            tf_var,
        )
    current_var = tf.compat.v1.trainable_variables()
    if not current_var:
        raise RuntimeError("no training variables")