model(features) does not return EagerTensor

I have built a model which I can only train with a custom loss that I am trying to debug.

For this I have a simple training loop:

for (mel_specs, pred_inp), labels in train_dataset:
    enc_predictions = model((mel_specs, pred_inp))  # <--- Returns a Tensor, not an EagerTensor
    input_lengths = get_padded_length(mel_specs[:, :, 0])
    label_lengths = get_padded_length(labels)
    print(enc_predictions)
    loss_value = rnnt_loss(enc_predictions, labels, input_lengths, label_lengths)
    print(loss_value)

The model is just:

model = tf.keras.Model(
    inputs=[mel_specs, pred_inp],
    outputs=[outputs]
)

The problem is that model((mel_specs, pred_inp)) only gives me a regular Tensor and not an EagerTensor, and I don't understand why. mel_specs and pred_inp are EagerTensors coming from train_dataset, which is a tf.data.Dataset.

What am I missing here?

Environment

$ pip freeze | grep tensorflow
tensorflow==2.2.0
tensorflow-addons==0.10.0
tensorflow-datasets==3.1.0
tensorflow-estimator==2.2.0
tensorflow-metadata==0.22.2
warprnnt-tensorflow==0.1

Update: MVCE

I was able to boil this down to the encoder part of the model. If I run the code below, it fails and prints:

Calling model(x) didn't return EagerTensor
Traceback (most recent call last):
    ...
    return loss_value, tape.gradient(loss_value, model.trainable_variables)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/backprop.py", line 1042, in gradient
    flat_grad = imperative_grad.imperative_grad(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/imperative_grad.py", line 71, in imperative_grad
    return pywrap_tfe.TFE_Py_TapeGradient(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/backprop.py", line 157, in _gradient_function
    return grad_fn(mock_op, *out_grads)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/ops/math_grad.py", line 252, in _MeanGrad
    sum_grad = _SumGrad(op, grad)[0]
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/ops/math_grad.py", line 211, in _SumGrad
    output_shape_kept_dims = math_ops.reduced_shape(input_shape,
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py", line 3735, in reduced_shape
    input_shape = input_shape.numpy()
AttributeError: 'Tensor' object has no attribute 'numpy'

The code:

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor


class TimeReduction(tf.keras.layers.Layer):
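    """Pads the time axis and stacks `reduction_factor` consecutive frames into the feature dimension."""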

    def __init__(self,
                 reduction_factor,
                 batch_size=None,
                 **kwargs):
        super(TimeReduction, self).__init__(**kwargs)
        self.reduction_factor = reduction_factor
        self.batch_size = batch_size

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size = self.batch_size
        if batch_size is None:
            batch_size = input_shape[0]
        max_time = input_shape[1]
        num_units = inputs.get_shape().as_list()[-1]
        outputs = inputs
        paddings = [[0, 0], [0, tf.math.floormod(max_time, self.reduction_factor)], [0, 0]]
        outputs = tf.pad(outputs, paddings)
        return tf.reshape(outputs, (batch_size, -1, num_units * self.reduction_factor))


def make_encoder_model(
    input_shape: tuple,
    out_dim: int,
    num_layers: int,
    d_model: int,
    proj_size,
    initializer=None,
    dtype=tf.float32,
    stateful: bool = False,
    dropout=0.5,
    reduction_index=1,
    reduction_factor=2,
):
    def lstm_cell():
        return tf.compat.v1.nn.rnn_cell.LSTMCell(
            d_model,
            num_proj=proj_size,
            initializer=initializer,
            dtype=dtype
        )

    batch_size = None if not stateful else 1

    inputs = tf.keras.Input(
        shape=input_shape,
        batch_size=batch_size,
        dtype=tf.float32
    )

    x = tf.keras.layers.BatchNormalization()(inputs)

    for i in range(num_layers):
        rnn_layer = tf.keras.layers.RNN(lstm_cell(), return_sequences=True, stateful=stateful)
        x = rnn_layer(x)
        x = tf.keras.layers.Dropout(dropout)(x)
        x = tf.keras.layers.LayerNormalization(dtype=dtype)(x)
        if i == reduction_index:
            x = TimeReduction(reduction_factor, batch_size=batch_size)(x)

    outputs = tf.keras.layers.Dense(out_dim)(x)

    return tf.keras.Model(
        inputs=[inputs],
        outputs=[outputs],
        name='encoder'
    )


def gradient(model, loss, inputs, y_true):
    y_true = tf.transpose(y_true, perm=(0, 2, 1))
    with tf.GradientTape() as tape:
        y_pred = model(inputs, training=True)
        loss_value = loss(y_true=y_true, y_pred=y_pred)
        return loss_value, tape.gradient(loss_value, model.trainable_variables)


def main():
    X, Y = [
        np.random.rand(100, 512),
        np.random.rand(100, 512)
    ], [[[0]*50], [[1]*50]]
    # assert len(X) == len(Y)

    encoder_model = make_encoder_model(
        input_shape=(None, 512),
        out_dim=1,
        num_layers=2,
        d_model=10,
        proj_size=23,
        dropout=0.5,
        reduction_index=1,
        reduction_factor=2
    )

    enc_dataset = tf.data.Dataset.from_generator(
        lambda: zip(X, Y),
        output_types=(tf.float32, tf.int32),
        output_shapes=([None, 512], [None, None]),
    ).batch(2)

    loss = tf.keras.losses.MeanSquaredError()

    for x, y in enc_dataset:
        from_predict = encoder_model.predict(x)
        from_call = encoder_model(x)
        if not isinstance(from_predict, np.ndarray):
            print("Calling model.predict(x) didn't return np.ndarray")
        if not isinstance(from_call, EagerTensor):
            print("Calling model(x) didn't return EagerTensor")
        loss_value, gradients = gradient(encoder_model, loss, x, y)
        print(loss_value)
        print(gradients)

    print('All done.')


if __name__ == '__main__':
    main()

Why are you using LSTM cells from compat.v1? I suspect that is what causes the compatibility problems.

On top of that, those "pure TensorFlow" RNN cells are not meant to be used with the Keras RNN layer anyway; they were designed for tf.nn.dynamic_rnn, which is deprecated and now only available in the compat.v1 module.

I suggest you use tf.keras.layers.LSTM directly, since it is much faster: it can dispatch to the highly optimized cuDNN GPU kernel. Alternatively, you can replace the compat.v1 LSTMCell with tf.keras.layers.LSTMCell and wrap it in tf.keras.layers.RNN.
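For illustration, here is a minimal sketch of what that swap could look like, reusing the TimeReduction layer from the MVCE above. The function name make_encoder_model_keras_lstm is made up for this example, and the Dense(proj_size) standing in for num_proj is my assumption, not something the original model had:

import tensorflow as tf

def make_encoder_model_keras_lstm(input_shape, out_dim, num_layers, d_model,
                                  proj_size, stateful=False, dropout=0.5,
                                  reduction_index=1, reduction_factor=2,
                                  dtype=tf.float32):
    batch_size = None if not stateful else 1
    inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32)
    x = tf.keras.layers.BatchNormalization()(inputs)
    for i in range(num_layers):
        # Keras-native LSTM instead of tf.compat.v1.nn.rnn_cell.LSTMCell;
        # equivalently: tf.keras.layers.RNN(tf.keras.layers.LSTMCell(d_model), ...)
        x = tf.keras.layers.LSTM(d_model, return_sequences=True, stateful=stateful)(x)
        x = tf.keras.layers.Dense(proj_size)(x)  # stands in for num_proj (assumption)
        x = tf.keras.layers.Dropout(dropout)(x)
        x = tf.keras.layers.LayerNormalization(dtype=dtype)(x)
        if i == reduction_index:
            x = TimeReduction(reduction_factor, batch_size=batch_size)(x)
    outputs = tf.keras.layers.Dense(out_dim)(x)
    return tf.keras.Model(inputs=[inputs], outputs=[outputs], name='encoder')

With this, the rest of the MVCE (the dataset loop and the GradientTape-based gradient function) can stay as it is.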