ValueError: No gradients provided for any variable - Keras Tensorflow 2.0

ValueError: No gradients provided for any variable - Keras Tensorflow 2.0

我正在尝试在 TensorFlow 网站上关注 this example,但它不起作用。

这是我的代码:

import tensorflow as tf

def vectorize(vector_like):
    return tf.convert_to_tensor(vector_like)

def batchify(vector):
    '''Make a batch out of a single example'''
    return vectorize([vector])

data = [(batchify([0]), batchify([0, 0, 0])), (batchify([1]), batchify([0, 0, 0])), (batchify([2]), batchify([0, 0, 0]))]
num_hidden = 5
num_classes = 3

opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
loss_fn = lambda: tf.keras.backend.cast(tf.keras.losses.mse(model(input), output), tf.float32)
var_list_fn = lambda: model.trainable_weights
for input, output in data:
    opt.minimize(loss_fn, var_list_fn)

有一段时间,我收到了关于损失函数数据类型错误(int 而不是 float)的警告,这就是我将转换添加到损失函数的原因。

我得到的不是网络训练,而是错误:

ValueError: No gradients provided for any variable: ['sequential/dense/kernel:0', 'sequential/dense/bias:0', 'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'].

为什么渐变没有通过?我做错了什么?

如果你想在 TF2 中操作渐变,你需要使用 GradientTape。例如以下作品。


opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))

with tf.GradientTape() as tape:
  loss = tf.keras.backend.mean(tf.keras.losses.mse(model(input),tf.cast(output, tf.float32)))

gradients = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))

编辑:

您实际上可以通过进行以下更改来使示例工作。

  • 仅对输出使用 cast 而不是完整的 loss_fn(注意我也在做 mean,因为我们优化了 w.r.t 损失均值)

"work",我的意思是它不会抱怨。但是您需要进一步调查以确保它按预期工作。

loss_fn = lambda: tf.keras.backend.mean(tf.keras.losses.mse(model(input), tf.cast(output, tf.float32)))
var_list_fn = lambda: model.trainable_weights

opt.minimize(loss_fn, var_list_fn)