是否可以使可训练变量不可训练?
Is it possible to make a trainable variable not trainable?
我在范围内创建了一个 可训练 变量。后来进入同一个作用域,设置作用域为reuse_variables
,用get_variable
取回同一个变量。但是,我无法将变量的可训练 属性 设置为 False
。我的 get_variable
行是这样的:
weight_var = tf.get_variable('weights', trainable = False)
但是变量'weights'
还在tf.trainable_variables
的输出中。
我可以使用 get_variable
将共享变量的 trainable
标志设置为 False
吗?
我想这样做的原因是我试图在我的模型中重用从 VGG 网络预训练的低级过滤器,我想像以前一样构建图形,检索权重变量,并将 VGG 过滤器值分配给权重变量,然后在接下来的训练步骤中保持它们固定不变。
查看文档和代码后,我无法找到从 TRAINABLE_VARIABLES
.
中删除变量的方法
事情是这样的:
- 第一次调用
tf.get_variable('weights', trainable=True)
时,变量被添加到TRAINABLE_VARIABLES
的列表中。
- 第二次调用
tf.get_variable('weights', trainable=False)
时,您得到了相同的变量,但参数 trainable=False
无效,因为该变量已经存在于 TRAINABLE_VARIABLES
的列表中(并且无法从那里删除它)
第一个解决方案
当调用优化器的minimize
方法时(参见doc.),您可以将var_list=[...]
作为参数传递给您想要优化器的变量。
比如想冻结VGG中除最后两层以外的所有层,可以将最后两层的权重传入var_list
.
第二种解决方案
您可以使用 tf.train.Saver()
保存变量并稍后恢复它们(参见 this tutorial)。
- 首先,您使用所有可训练 变量训练整个 VGG 模型。您通过调用
saver.save(sess, "/path/to/dir/model.ckpt")
. 将它们保存在检查点文件中
- 然后(在另一个文件中)用 不可训练的 变量训练第二个版本。您加载之前使用
saver.restore(sess, "/path/to/dir/model.ckpt")
. 存储的变量
您可以选择在检查点文件中只保存部分变量。有关详细信息,请参阅 doc。
当您只想训练或优化预训练网络的某些层时,这就是您需要了解的内容。
TensorFlow 的 minimize
方法采用可选参数 var_list
,即要通过反向传播调整的变量列表。
如果您不指定 var_list
,图表中的任何 TF 变量都可以由优化器进行调整。当您在 var_list
中指定一些变量时,TF 会保持所有其他变量不变。
这是 jonbruner 和他的合作者使用的脚本示例。
tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)
这会找到他们之前定义的所有变量名称中包含 "g_" 的变量,将它们放入列表中,并对它们运行 ADAM 优化器。
您可以在 Quora
上找到相关答案
为了从可训练变量列表中删除变量,您可以先通过以下方式访问集合:
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
在那里,trainable_collection
包含对可训练变量集合的引用。如果从该列表中弹出元素,例如 trainable_collection.pop(0)
,您将从可训练变量中删除相应的变量,因此不会训练该变量。
虽然这适用于 pop
,但我仍在努力寻找一种方法来正确使用 remove
和正确的参数,因此我们不依赖于变量的索引。
编辑: 假设你有图中变量的名称(你可以通过检查图 protobuf 或更简单的使用 Tensorboard 来获得),你可以使用它循环遍历可训练变量列表,然后从可训练集合中删除变量。
示例:假设我想要训练名称为 "batch_normalization/gamma:0"
和 "batch_normalization/beta:0"
NOT 的变量,但它们已经添加到 TRAINABLE_VARIABLES
集合中。我能做的是:
`
#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
#uses the attribute 'name' of the variable
if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
variables_to_remove.append(vari)
for rem in variables_to_remove:
trainable_collection.remove(rem)
`
这将成功地从集合中移除这两个变量,它们将不再被训练。
您可以使用 tf.get_collection_ref 来获取集合的引用而不是 tf.get_collection
我在范围内创建了一个 可训练 变量。后来进入同一个作用域,设置作用域为reuse_variables
,用get_variable
取回同一个变量。但是,我无法将变量的可训练 属性 设置为 False
。我的 get_variable
行是这样的:
weight_var = tf.get_variable('weights', trainable = False)
但是变量'weights'
还在tf.trainable_variables
的输出中。
我可以使用 get_variable
将共享变量的 trainable
标志设置为 False
吗?
我想这样做的原因是我试图在我的模型中重用从 VGG 网络预训练的低级过滤器,我想像以前一样构建图形,检索权重变量,并将 VGG 过滤器值分配给权重变量,然后在接下来的训练步骤中保持它们固定不变。
查看文档和代码后,我无法找到从 TRAINABLE_VARIABLES
.
事情是这样的:
- 第一次调用
tf.get_variable('weights', trainable=True)
时,变量被添加到TRAINABLE_VARIABLES
的列表中。 - 第二次调用
tf.get_variable('weights', trainable=False)
时,您得到了相同的变量,但参数trainable=False
无效,因为该变量已经存在于TRAINABLE_VARIABLES
的列表中(并且无法从那里删除它)
第一个解决方案
当调用优化器的minimize
方法时(参见doc.),您可以将var_list=[...]
作为参数传递给您想要优化器的变量。
比如想冻结VGG中除最后两层以外的所有层,可以将最后两层的权重传入var_list
.
第二种解决方案
您可以使用 tf.train.Saver()
保存变量并稍后恢复它们(参见 this tutorial)。
- 首先,您使用所有可训练 变量训练整个 VGG 模型。您通过调用
saver.save(sess, "/path/to/dir/model.ckpt")
. 将它们保存在检查点文件中
- 然后(在另一个文件中)用 不可训练的 变量训练第二个版本。您加载之前使用
saver.restore(sess, "/path/to/dir/model.ckpt")
. 存储的变量
您可以选择在检查点文件中只保存部分变量。有关详细信息,请参阅 doc。
当您只想训练或优化预训练网络的某些层时,这就是您需要了解的内容。
TensorFlow 的 minimize
方法采用可选参数 var_list
,即要通过反向传播调整的变量列表。
如果您不指定 var_list
,图表中的任何 TF 变量都可以由优化器进行调整。当您在 var_list
中指定一些变量时,TF 会保持所有其他变量不变。
这是 jonbruner 和他的合作者使用的脚本示例。
tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)
这会找到他们之前定义的所有变量名称中包含 "g_" 的变量,将它们放入列表中,并对它们运行 ADAM 优化器。
您可以在 Quora
上找到相关答案为了从可训练变量列表中删除变量,您可以先通过以下方式访问集合:
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
在那里,trainable_collection
包含对可训练变量集合的引用。如果从该列表中弹出元素,例如 trainable_collection.pop(0)
,您将从可训练变量中删除相应的变量,因此不会训练该变量。
虽然这适用于 pop
,但我仍在努力寻找一种方法来正确使用 remove
和正确的参数,因此我们不依赖于变量的索引。
编辑: 假设你有图中变量的名称(你可以通过检查图 protobuf 或更简单的使用 Tensorboard 来获得),你可以使用它循环遍历可训练变量列表,然后从可训练集合中删除变量。
示例:假设我想要训练名称为 "batch_normalization/gamma:0"
和 "batch_normalization/beta:0"
NOT 的变量,但它们已经添加到 TRAINABLE_VARIABLES
集合中。我能做的是:
`
#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
#uses the attribute 'name' of the variable
if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
variables_to_remove.append(vari)
for rem in variables_to_remove:
trainable_collection.remove(rem)
` 这将成功地从集合中移除这两个变量,它们将不再被训练。
您可以使用 tf.get_collection_ref 来获取集合的引用而不是 tf.get_collection