What's the best way to access single gradients in a batch in TensorFlow?
I am currently analyzing how gradients evolve during the training of a CNN with TensorFlow 2.x. What I want to do is compare each gradient within a batch with the gradient of the whole batch. Currently I use this simple code snippet in every training step:
[...]
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
[...]
# One training step
# x_train is a batch of input data, y_train the corresponding labels
def train_step(model, optimizer, x_train, y_train):
    # Process batch
    with tf.GradientTape() as tape:
        batch_predictions = model(x_train, training=True)
        batch_loss = loss_object(y_train, batch_predictions)
    batch_grads = tape.gradient(batch_loss, model.trainable_variables)
    # Do something with gradient of whole batch
    # ...

    # Process each data point in the current batch
    for index in range(len(x_train)):
        with tf.GradientTape() as single_tape:
            single_prediction = model(x_train[index:index+1], training=True)
            single_loss = loss_object(y_train[index:index+1], single_prediction)
        single_grad = single_tape.gradient(single_loss, model.trainable_variables)
        # Do something with gradient of single data input
        # ...

    # Use batch gradient to update network weights
    optimizer.apply_gradients(zip(batch_grads, model.trainable_variables))
    train_loss(batch_loss)
    train_accuracy(y_train, batch_predictions)
My main problem is that the computation time explodes when each gradient is computed individually, even though those computations should already have been performed by TensorFlow while computing the batch gradient. The reason is that GradientTape and compute_gradients always return a single gradient, no matter whether one data point or many data points are supplied, so this computation has to be repeated for every single data point.
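For illustration, here is a minimal sketch of that aggregation behavior (a toy linear problem with made-up variable names, not the CNN above): even when the target passed to tape.gradient is a vector of per-example losses, the result is a single gradient summed over the batch.

import tensorflow as tf

# Toy linear model: 10 examples, 4 features, 1 output
x = tf.random.uniform((10, 4))
w = tf.Variable(tf.random.uniform((4, 1)))
y = tf.random.uniform((10, 1))

with tf.GradientTape() as tape:
    # One squared-error loss value per example, shape (10,)
    per_example_loss = tf.reduce_sum((x @ w - y) ** 2, axis=1)
grad = tape.gradient(per_example_loss, w)
print(grad.shape)  # (4, 1): a single gradient aggregated over the whole batch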
I know that I could update the network by computing the batch gradient from all the individual gradients computed for each data point, but this saves only a negligible amount of computation time.
Is there a more efficient way to compute the individual gradients?
You can use the jacobian method of the gradient tape to get the Jacobian, which gives you the gradient for each individual loss value:
import tensorflow as tf
# Make a random linear problem
tf.random.set_seed(0)
# Random input batch of ten four-vector examples
x = tf.random.uniform((10, 4))
# Random weights
w = tf.random.uniform((4, 2))
# Random batch label
y = tf.random.uniform((10, 2))
with tf.GradientTape() as tape:
    tape.watch(w)
    # Prediction
    p = x @ w
    # Loss
    loss = tf.losses.mean_squared_error(y, p)
# Compute Jacobian
j = tape.jacobian(loss, w)
# The Jacobian gives you the gradient for each loss value
print(j.shape)
# (10, 4, 2)
# Gradient of the loss wrt the weights for the first example
tf.print(j[0])
# [[0.145728424 0.0756840706]
# [0.103099883 0.0535449386]
# [0.267220169 0.138780832]
# [0.280130595 0.145485848]]
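Applied to the train_step from the question, the same idea might look roughly like the sketch below; the per-example loss object, the helper name per_example_grads, and the mean reduction over the batch axis are my assumptions, not part of the original post. The key point is to keep the loss un-reduced so that tape.jacobian yields one gradient per example.

# Un-reduced loss: one loss value per example instead of a batch average
per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)

def per_example_grads(model, x_train, y_train):
    with tf.GradientTape() as tape:
        predictions = model(x_train, training=True)
        # Shape (batch_size,): one loss value per example
        losses = per_example_loss(y_train, predictions)
    # One Jacobian per trainable variable; each has shape
    # (batch_size,) + variable.shape, i.e. a per-example gradient
    jacobians = tape.jacobian(losses, model.trainable_variables)
    # Averaging over the batch axis recovers the usual batch gradient
    # (matching the default mean reduction of the batch loss)
    batch_grads = [tf.reduce_mean(j, axis=0) for j in jacobians]
    return jacobians, batch_grads

Note that the Jacobian materializes a gradient per example for every variable, so it trades memory for the speed-up over looping; for large models it can still be expensive.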