Tensorflow 2.0:None values are not supported in tf.function while trying to print last result 错误

Tensorflow 2.0 : None values are not supported error in tf.function while trying to print last result

我正在尝试在 tf.function

中打印最后一批的结果
import tensorflow as tf

def small_data():
    for i in range(10):
        yield 3, 2

data = tf.data.Dataset.from_generator(
    small_data, (tf.int32, tf.int32), )

def result(data):
    """
    Psuedo code for a model which outputs multiple layer outputs
    :param data:
    :return:
    """
    return tf.random.normal(shape=[1, 2]), tf.random.normal(shape=[1, 2]),data[0]

@tf.function
def train(dataset):
    batch_result = None
    for batch in dataset:
        batch_result = result(data)
    tf.print("Final batch result is", batch_result)


train(dataset=data)


错误

 raise ValueError("None values not supported.")

    ValueError: None values not supported.

result 函数实际上是一个 Keras 模型,它会产生不同形状的图层输出。如果我删除 batch_result=None 赋值并将 tf.print 移到循环内,它会为每个批次打印。我只想打印最后一批的结果。另外,我不确定输入循环的记录数。我也尝试了多种变体,但没有任何效果。我怎样才能在 tensorflow 2.0 中实现这一点。

你必须模仿 batch_result 的预期形式。这有效:

@tf.function
def train(dataset):
    batch_result = result(dataset.take(1))
    for batch in dataset:
        batch_result = result(data)
    tf.print("Final batch result is", batch_result)

有点老套,但这可能有用:

@tf.function
def train(dataset):
    batch_result = result(next(dataset.__iter__()))
    for batch in dataset:
        batch_result = result(data)
    tf.print("Final batch result is", batch_result)