TensorFlow 从文件中保存 into/loading 个图形

TensorFlow saving into/loading a graph from a file

从我目前收集到的信息来看,有几种不同的方法可以将 TensorFlow 图转储到文件中,然后将其加载到另一个程序中,但我一直无法找到明确的方法 examples/information关于他们的工作方式。我已经知道的是:

  1. 使用 tf.train.Saver() 将模型的变量保存到检查点文件 (.ckpt) 中,稍后恢复它们 (source)
  2. 将模型保存到 .pb 文件中,然后使用 tf.train.write_graph()tf.import_graph_def() (source)
  3. 将其加载回来
  4. 从 .pb 文件加载模型,重新训练它,然后使用 Bazel 将其转储到新的 .pb 文件中 (source)
  5. 冻结图形以将图形和权重一起保存(source)
  6. 使用as_graph_def()保存模型,对于weights/variables,映射成常量()

但是,关于这些不同的方法,我无法解决几个问题:

  1. 关于检查点文件,他们只保存模型训练的权重吗?是否可以将检查点文件加载到新程序中,并用于 运行 模型,或者它们只是作为在某个 time/stage 时保存模型中权重的方法?
  2. 关于tf.train.write_graph(),weights/variables是否也保存了?
  3. 关于Bazel,是否只能从.pb文件中保存into/load用于再训练?是否有一个简单的 Bazel 命令可以将图形转储到 .pb 中?
  4. 关于冻结,是否可以使用tf.import_graph_def()加载冻结图?
  5. TensorFlow 的 Android 演示从 .pb 文件加载到 Google 的 Inception 模型中。如果我想替换我自己的 .pb 文件,我该怎么做呢?我需要更改任何原生 code/methods 吗?
  6. 总的来说,所有这些方法之间到底有什么区别?或者更广泛地说,as_graph_def()/.ckpt/.pb 之间有什么区别?

简而言之,我正在寻找的是一种将图形(如各种操作等)及其 weights/variables 保存到文件中的方法,然后可以使用该文件加载图表和权重到另一个程序中,以供使用(不一定continuing/retraining)。

关于此主题的文档不是很简单,因此任何 answers/information 将不胜感激。

在 TensorFlow 中有很多方法可以解决保存模型的问题,这可能会让人有点困惑。依次回答您的每个子问题:

  1. 检查点文件(通过调用 saver.save() on a tf.train.Saver object) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights. Note that calling saver.save() also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorial 生成更多详细信息。

  2. tf.train.write_graph()只写图结构;不是权重。

  3. Bazel 与读写 TensorFlow 图表无关。 (也许我误解了你的问题:请随时在评论中澄清。)

  4. 可以使用 tf.import_graph_def() 加载冻结图。在这种情况下,权重(通常)嵌入在图中,因此您不需要加载单独的检查点。

  5. 主要变化是更新输入模型的张量名称,以及从模型中获取的张量名称。在 TensorFlow Android 演示中,这将对应于传递给 TensorFlowClassifier.initializeTensorFlow().

  6. inputNameoutputName 字符串
  7. GraphDef是程序结构,在训练过程中通常不会改变。检查点是训练过程状态的快照,通常在训练过程的每一步都会发生变化。因此,TensorFlow 对这些类型的数据使用了不同的存储格式,底层 API 提供了不同的方式来保存和加载它们。更高级别的库,例如 MetaGraphDef libraries, Keras, and skflow 建立在这些机制之上,以提供更方便的方法来保存和恢复整个模型。

您可以试试下面的代码:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)