TF 对象检测 Zoo 模型没有可训练变量?
TF Objection Detection Zoo models don't have Trainable Variables?
TF Objection Detection Zoo中的模型有meta+ckpt文件,Frozen.pb文件,Saved_model文件。
我尝试使用 meta+ckpt 文件进一步训练,并为研究目的提取特定张量的一些权重。我看到模型没有任何可训练的变量。
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)
上面的代码片段给出了一个 []
列表。我也尝试使用以下方法。
vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)
我又得到了一个 []
列表。
这怎么可能?模型是否剥离了变量?还是 tf.Variable(trainable=False)
?我在哪里可以获得具有有效可训练变量的 meta+ckpt 文件。我专门看SSD+mobilnet机型
更新:
以下是我在 class 中用于 restoring.It 的代码片段,因为我正在为某些应用程序制作自定义工具。
def _importer(self):
sess = tf.InteractiveSession()
with sess.as_default():
reader = tf.train.import_meta_graph(self.metafile,
clear_devices=True)
reader.restore(sess, self.ckptfile)
def _read_graph(self):
sess = tf.get_default_session()
with sess.as_default():
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)
更新 2:
我还尝试了以下代码片段。简约复古风
model_dir = 'ssd_mobilenet_v2/'
meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()
sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
with tf.Session() as sess:
reader = tf.train.import_meta_graph(meta,clear_devices=True)
reader.restore(sess,ckpt)
vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for var in vari:
print(var.name,"\n")
上面的代码片段还给出了 []
变量列表
经过一些研究,您问题的最终答案是不,他们没有。很明显,直到您意识到 saved_model
中的 variables
目录是空的。
对象检测模型动物园提供的检查点文件包含以下文件:
.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
|-- saved_model.pb
`-- variables
pipeline.config
是保存模型的配置文件,frozen_inference_graph.pb
是off-the-shelf推理。请注意,checkpoint
、model.ckpt.data-00000-of-00001
、model.ckpt.meta
和 model.ckpt.index
都对应于检查点 。 ( 你可以找到一个很好的解释)
所以当你想得到可训练的变量时,唯一有用的就是saved_model
目录。
Use SavedModel to save and load your model—variables, the graph, and the graph's metadata. This is a language-neutral, recoverable, hermetic serialization format that enables higher-level systems and tools to produce, consume, and transform TensorFlow models.
要恢复 SavedModel
你可以使用 api tf.saved_model.loader.load()
,这个 api 包含一个名为 tags
的参数,它指定MetaGraphDef
。所以如果你想得到可训练的变量,你需要在调用api.
时指定tag_constants.TRAINING
我试图调用此 api 来恢复变量,但它给了我一个错误
MetaGraphDef associated with tags 'train' could not be found in SavedModel. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: saved_model_cli
所以我执行了这个 saved_model_cli
命令来检查 SavedModel
.
中可用的所有标签
#from directory saved_model
saved_model_cli show --dir . --all
输出为
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
...
因此 SavedModel
中没有标签 train
,只有 serve
。因此,此处的 SavedModel
仅用于 tensorflow 服务。这意味着当这些文件在创建时未使用标记 training
指定时,无法从这些文件中恢复训练变量。
P.S.: 以下代码是我用来恢复 SavedModel
的代码。设置tag_constants.TRAINING
时无法加载完成,设置tag_constants.SERVING
时加载成功但变量为空
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(variables)
P.P.S: 我找到了创建 SavedModel
here 的脚本。可见创建SavedModel
.
时确实没有train
标签
TF Objection Detection Zoo中的模型有meta+ckpt文件,Frozen.pb文件,Saved_model文件。
我尝试使用 meta+ckpt 文件进一步训练,并为研究目的提取特定张量的一些权重。我看到模型没有任何可训练的变量。
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)
上面的代码片段给出了一个 []
列表。我也尝试使用以下方法。
vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)
我又得到了一个 []
列表。
这怎么可能?模型是否剥离了变量?还是 tf.Variable(trainable=False)
?我在哪里可以获得具有有效可训练变量的 meta+ckpt 文件。我专门看SSD+mobilnet机型
更新:
以下是我在 class 中用于 restoring.It 的代码片段,因为我正在为某些应用程序制作自定义工具。
def _importer(self):
sess = tf.InteractiveSession()
with sess.as_default():
reader = tf.train.import_meta_graph(self.metafile,
clear_devices=True)
reader.restore(sess, self.ckptfile)
def _read_graph(self):
sess = tf.get_default_session()
with sess.as_default():
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)
更新 2:
我还尝试了以下代码片段。简约复古风
model_dir = 'ssd_mobilenet_v2/'
meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()
sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
with tf.Session() as sess:
reader = tf.train.import_meta_graph(meta,clear_devices=True)
reader.restore(sess,ckpt)
vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for var in vari:
print(var.name,"\n")
上面的代码片段还给出了 []
变量列表
经过一些研究,您问题的最终答案是不,他们没有。很明显,直到您意识到 saved_model
中的 variables
目录是空的。
对象检测模型动物园提供的检查点文件包含以下文件:
.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
|-- saved_model.pb
`-- variables
pipeline.config
是保存模型的配置文件,frozen_inference_graph.pb
是off-the-shelf推理。请注意,checkpoint
、model.ckpt.data-00000-of-00001
、model.ckpt.meta
和 model.ckpt.index
都对应于检查点 。 (
所以当你想得到可训练的变量时,唯一有用的就是saved_model
目录。
Use SavedModel to save and load your model—variables, the graph, and the graph's metadata. This is a language-neutral, recoverable, hermetic serialization format that enables higher-level systems and tools to produce, consume, and transform TensorFlow models.
要恢复 SavedModel
你可以使用 api tf.saved_model.loader.load()
,这个 api 包含一个名为 tags
的参数,它指定MetaGraphDef
。所以如果你想得到可训练的变量,你需要在调用api.
tag_constants.TRAINING
我试图调用此 api 来恢复变量,但它给了我一个错误
MetaGraphDef associated with tags 'train' could not be found in SavedModel. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI:
saved_model_cli
所以我执行了这个 saved_model_cli
命令来检查 SavedModel
.
#from directory saved_model
saved_model_cli show --dir . --all
输出为
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
...
因此 SavedModel
中没有标签 train
,只有 serve
。因此,此处的 SavedModel
仅用于 tensorflow 服务。这意味着当这些文件在创建时未使用标记 training
指定时,无法从这些文件中恢复训练变量。
P.S.: 以下代码是我用来恢复 SavedModel
的代码。设置tag_constants.TRAINING
时无法加载完成,设置tag_constants.SERVING
时加载成功但变量为空
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(variables)
P.P.S: 我找到了创建 SavedModel
here 的脚本。可见创建SavedModel
.
train
标签