是否可以使可训练变量不可训练?

Is it possible to make a trainable variable not trainable?

我在范围内创建了一个 可训练 变量。后来进入同一个作用域,设置作用域为reuse_variables,用get_variable取回同一个变量。但是,我无法将变量的可训练 属性 设置为 False。我的 get_variable 行是这样的:

weight_var = tf.get_variable('weights', trainable = False)

但是变量'weights'还在tf.trainable_variables的输出中。

我可以使用 get_variable 将共享变量的 trainable 标志设置为 False 吗?

我想这样做的原因是我试图在我的模型中重用从 VGG 网络预训练的低级过滤器,我想像以前一样构建图形,检索权重变量,并将 VGG 过滤器值分配给权重变量,然后在接下来的训练步骤中保持它们固定不变。

查看文档和代码后,我无法找到从 TRAINABLE_VARIABLES.

中删除变量的方法

事情是这样的:

  • 第一次调用tf.get_variable('weights', trainable=True)时,变量被添加到TRAINABLE_VARIABLES的列表中。
  • 第二次调用 tf.get_variable('weights', trainable=False) 时,您得到了相同的变量,但参数 trainable=False 无效,因为该变量已经存在于 TRAINABLE_VARIABLES 的列表中(并且无法从那里删除它)

第一个解决方案

当调用优化器的minimize方法时(参见doc.),您可以将var_list=[...]作为参数传递给您想要优化器的变量。

比如想冻结VGG中除最后两层以外的所有层,可以将最后两层的权重传入var_list.

第二种解决方案

您可以使用 tf.train.Saver() 保存变量并稍后恢复它们(参见 this tutorial)。

  • 首先,您使用所有可训练 变量训练整个 VGG 模型。您通过调用 saver.save(sess, "/path/to/dir/model.ckpt").
  • 将它们保存在检查点文件中
  • 然后(在另一个文件中)用 不可训练的 变量训练第二个版本。您加载之前使用 saver.restore(sess, "/path/to/dir/model.ckpt").
  • 存储的变量

您可以选择在检查点文件中只保存部分变量。有关详细信息,请参阅 doc

当您只想训练或优化预训练网络的某些层时,这就是您需要了解的内容。

TensorFlow 的 minimize 方法采用可选参数 var_list,即要通过反向传播调整的变量列表。

如果您不指定 var_list,图表中的任何 TF 变量都可以由优化器进行调整。当您在 var_list 中指定一些变量时,TF 会保持所有其他变量不变。

这是 jonbruner 和他的合作者使用的脚本示例。

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

这会找到他们之前定义的所有变量名称中包含 "g_" 的变量,将它们放入列表中,并对它们运行 ADAM 优化器。

您可以在 Quora

上找到相关答案

为了从可训练变量列表中删除变量,您可以先通过以下方式访问集合: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 在那里,trainable_collection 包含对可训练变量集合的引用。如果从该列表中弹出元素,例如 trainable_collection.pop(0),您将从可训练变量中删除相应的变量,因此不会训练该变量。

虽然这适用于 pop,但我仍在努力寻找一种方法来正确使用 remove 和正确的参数,因此我们不依赖于变量的索引。

编辑: 假设你有图中变量的名称(你可以通过检查图 protobuf 或更简单的使用 Tensorboard 来获得),你可以使用它循环遍历可训练变量列表,然后从可训练集合中删除变量。 示例:假设我想要训练名称为 "batch_normalization/gamma:0""batch_normalization/beta:0" NOT 的变量,但它们已经添加到 TRAINABLE_VARIABLES 集合中。我能做的是: `

#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

` 这将成功地从集合中移除这两个变量,它们将不再被训练。

您可以使用 tf.get_collection_ref 来获取集合的引用而不是 tf.get_collection