Internal error: Tried to take gradients (or similar) of a variable without handle data on running TF HUB BERT inside tf.GradientTape

Internal error: Tried to take gradients (or similar) of a variable without handle data on running TF HUB BERT inside tf.GradientTape

我正在尝试在 Tensorflow 2.4 中的持久梯度带内训练 bert_en_uncased_L-12_H-768_A-12 TF HUB 模型。以下是我的代码的简化版本。

import tensorflow as tf
import tensorflow_hub as hub

input_mask = tf.keras.layers.Input(shape=4, dtype=tf.int32)
input_word_ids = tf.keras.layers.Input(shape=4, dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=4, dtype=tf.int32)

bert = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",
    trainable=True)({"input_mask": input_mask, 'input_word_ids': input_type_ids, "input_type_ids": input_type_ids})

dense = tf.keras.layers.Dense(units=1)(bert['pooled_output'])

encode = tf.keras.models.Model([input_mask, input_word_ids, input_type_ids], dense)
import numpy as np

data = np.zeros((1, 4))


@tf.function
def run():
    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
        tape.watch(encode.trainable_weights)
        encode([data, data, data], training=True)


run()

错误

  raise ValueError("Internal error: Tried to take gradients (or similar) "

    ValueError: Internal error: Tried to take gradients (or similar) of a variable without handle data:
    Tensor("StatefulPartitionedCall:1079", dtype=resource)

此错误仅在

时出现

我认为没有必要使用 persistent=True,你应该使用 False。通常,当我们需要计算 tape 范围内的损失时,它被设置为 True,以便我们可以计算范围外的梯度,src。在你上面的代码示例中,我认为你不需要这个。

您的代码中的另一个错字可能需要更正。它有一个错误的输入映射。

bert = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",
    trainable=True)(
        {
           "input_mask": input_mask, 
           'input_word_ids': input_word_ids,  # < ---------
           "input_type_ids": input_type_ids
         }
     )

运行 具有这些更改的代码

import numpy as np

data = np.zeros((1, 4))

@tf.function
def run():
    with tf.GradientTape( watch_accessed_variables=False) as tape:
        tape.watch(encode.trainable_weights)
        y = encode([data, data, data], training=True)
    tf.print(y)

run()

# [[-0.799545228]]

在保存和加载 SavedModel 时,为 SavedModel 使用持久性 GradientTapes 需要 TensorFlow 2.5+。请关注 https://github.com/tensorflow/hub/issues/622 以获取有关 TF2.5 发布的更新以及针对 BERT 等更新的 SavedModels

M. Innat 的回答解释了如何使用标准的非持久性 GradientTapes 来避免该问题。