如何让 tensorflow.keras.model.predict() 输出生成器?

How to make tensorflow.keras.model.predict() output a generator?

我知道,tensorflow.keras.model.predict() 可以接收一个 generator 或一个 numpy array,但是这个函数的输出总是一个 numpy array。有时输出 array 会太大以至于出现 OOM Memory Errors

调用tensorflow.keras.model.predict() 时有没有输出生成器的建议??

制作生成器包装器:

def generator_predict(data_gen):
    for data in data_gen:
        yield tensorflow.keras.model.predict(data)

那么当你使用它时:

predict_gen = generator_predict(data_gen)

predicted_output = next(predict_gen)
...

类似的东西应该可以正常工作。