如何在检查点中列出某些变量?

How do I list certain variables in the checkpoint?

我正在使用自动编码器。我的检查点包含网络的完整状态(即编码器、解码器、优化器等)。我想愚弄编码。因此,在我的评估模式中,我只需要网络的解码器部分。

我如何才能只从现有检查点读取几个特定变量,以便我可以在另一个模型中重用它们的值?

checkpoint_utils.py 中有 list_variables 方法,可让您查看所有已保存的变量。

但是,对于您的用例,使用 Saver 可能更容易恢复。如果您在保存检查点时知道变量的名称,则可以创建一个新的保存器,并告诉它将这些名称初始化为新的 Variable 对象(可能具有不同的名称)。这在 CIFAR 示例中用于 select 恢复 Howto

中的 subset of variables. See Choosing which Variables to Save and Restore

另一种方式,将打印所有检查点张量(或仅一个,如果指定)及其内容:

from tensorflow.python.tools import inspect_checkpoint as inch
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True)
"""
Args:
  file_name: Name of the checkpoint file.
  tensor_name: Name of the tensor in the checkpoint file to print.
  all_tensors: Boolean indicating whether to print all tensors.
"""

它会一直打印张量的内容。

并且,在我们使用它的同时,这里是如何使用 checkpoint_utils.py(由先前的答案建议):

from tensorflow.contrib.framework.python.framework import checkpoint_utils

var_list = checkpoint_utils.list_variables('./')
for v in var_list:
    print(v)

您可以使用

查看 .ckpt 文件中保存的变量
import tensorflow as tf

variables_in_checkpoint = tf.train.list_variables('path.ckpt')

print("Variables found in checkpoint file",variables_in_checkpoint)