TF 对象检测 Zoo 模型没有可训练变量?

TF Objection Detection Zoo models don't have Trainable Variables?

TF Objection Detection Zoo中的模型有meta+ckpt文件,Frozen.pb文件,Saved_model文件。

我尝试使用 meta+ckpt 文件进一步训练,并为研究目的提取特定张量的一些权重。我看到模型没有任何可训练的变量。

vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)

上面的代码片段给出了一个 [] 列表。我也尝试使用以下方法。

vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)

我又得到了一个 [] 列表。

这怎么可能?模型是否剥离了变量?还是 tf.Variable(trainable=False) ?我在哪里可以获得具有有效可训练变量的 meta+ckpt 文件。我专门看SSD+mobilnet机型

更新:

以下是我在 class 中用于 restoring.It 的代码片段,因为我正在为某些应用程序制作自定义工具。

def _importer(self):
    sess = tf.InteractiveSession()
    with sess.as_default():
        reader = tf.train.import_meta_graph(self.metafile,
                                            clear_devices=True)
        reader.restore(sess, self.ckptfile)

def _read_graph(self):
    sess = tf.get_default_session()
    with sess.as_default():
        vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        print(vars)

更新 2:

我还尝试了以下代码片段。简约复古风

model_dir = 'ssd_mobilenet_v2/'

meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()

sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        reader = tf.train.import_meta_graph(meta,clear_devices=True)
        reader.restore(sess,ckpt)

        vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        for var in vari:
            print(var.name,"\n")

上面的代码片段还给出了 [] 变量列表

经过一些研究,您问题的最终答案是不,他们没有。很明显,直到您意识到 saved_model 中的 variables 目录是空的。

对象检测模型动物园提供的检查点文件包含以下文件:

.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
    |-- saved_model.pb
    `-- variables

pipeline.config是保存模型的配置文件,frozen_inference_graph.pb是off-the-shelf推理。请注意,checkpointmodel.ckpt.data-00000-of-00001model.ckpt.metamodel.ckpt.index 都对应于检查点 。 ( 你可以找到一个很好的解释)

所以当你想得到可训练的变量时,唯一有用的就是saved_model目录。

Use SavedModel to save and load your model—variables, the graph, and the graph's metadata. This is a language-neutral, recoverable, hermetic serialization format that enables higher-level systems and tools to produce, consume, and transform TensorFlow models.

要恢复 SavedModel 你可以使用 api tf.saved_model.loader.load(),这个 api 包含一个名为 tags 的参数,它指定MetaGraphDef。所以如果你想得到可训练的变量,你需要在调用api.

时指定tag_constants.TRAINING

我试图调用此 api 来恢复变量,但它给了我一个错误

MetaGraphDef associated with tags 'train' could not be found in SavedModel. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: saved_model_cli

所以我执行了这个 saved_model_cli 命令来检查 SavedModel.

中可用的所有标签
#from directory saved_model
saved_model_cli show --dir . --all

输出为

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
  ...

因此 SavedModel 中没有标签 train,只有 serve。因此,此处的 SavedModel 仅用于 tensorflow 服务。这意味着当这些文件在创建时未使用标记 training 指定时,无法从这些文件中恢复训练变量。

P.S.: 以下代码是我用来恢复 SavedModel 的代码。设置tag_constants.TRAINING时无法加载完成,设置tag_constants.SERVING时加载成功但变量为空

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  print(variables)

P.P.S: 我找到了创建 SavedModel here 的脚本。可见创建SavedModel.

时确实没有train标签