为什么损失函数在 tf.GradientTape 块内部而梯度计算在外部?

Why loss function is inside tf.GradientTape block and gradients calculation is outside?

我是 Tensorflow 的新人。在教科书示例中,我看到以下代码旨在使用 Tensorflow 2.x API 训练简单的线性模型:

m = tf.Variable(0.)
b = tf.Variable(0.)
def predict_y_value(x):
    y = m * x + b
    return y
def squared_error(y_pred, y_true):
    return tf.reduce_mean(tf.square(y_pred - y_true))
learning_rate = 0.05
steps = 500
for i in range(steps):
    with tf.GradientTape() as tape:
        predictions = predict_y_value(x_train)
        loss = squared_error(predictions, y_train)
    gradients = tape.gradient(loss, [m, b])
    m.assign_sub(gradients[0] * learning_rate)
    b.assign_sub(gradients[1] * learning_rate)
print ("m: %f, b: %f" % (m.numpy(), b.numpy()))

为什么需要将损失函数的定义包含在块with tf.GradientTape() as tape中,但是gradients = tape.gradient(loss, [m, b])代码行在with块之外?

我知道它可能是 Python 语言特定的,但我似乎不清楚这种结构。 Python上下文管理器在这里的作用是什么?

来自 tensorflow 文档,

By default GradientTape will automatically watch any trainable variables that are accessed inside the context.

直观地说,这种方法大大提高了灵活性。例如,它允许您编写(伪)代码如下:

inputs, labels = get_training_batch()
inputs_preprocessed = some_tf_ops(inputs)
with tf.GradientTape() as tape:
    pred = model(inputs_preprocessed)
    loss = compute_loss(labels, pred)

grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# For example, let's attach a model that takes the above model's output as input
next_step_inputs, next_step_labels = process(pred)

with tf.GradientTape() as tape:
    pred = another_model(next_step_inputs)
    another_loss = compute_loss(next_step_labels, pred)

grads = tape.gradient(another_loss, another_model.trainable_variables)
optimizer.apply_gradients(zip(grads, another_model.trainable_variables))

上面的例子可能看起来很复杂,但它解释了需要极大灵活性的极端情况。

  1. 我们不希望some_tf_opsprocess在梯度计算中发挥作用,因为它们是预处理步骤。

  2. 我们要计算多个模型的梯度,具有某种关系

一个实际的例子是训练 GAN,尽管更简单的实现是可能的。

tape.gradient 放在 TapeGradient() 之外会重置上下文并为垃圾收集器释放资源。

注2等效示例:

with tf.GradientTape() as t:
  loss = loss_fn()
with tf.GradientTape() as t:
  loss += other_loss_fn()
t.gradient(loss, ...)         # Only differentiates other_loss_fn, not loss_fn

下面等同于上面

with tf.GradientTape() as t:
  loss = loss_fn()
  t.reset()
  loss += other_loss_fn()
t.gradient(loss, ...)         # Only differentiates other_loss_fn, not loss_fn 

source