如何让 tensorflow.keras.model.predict() 输出生成器?
How to make tensorflow.keras.model.predict() output a generator?
我知道,tensorflow.keras.model.predict()
可以接收一个 generator
或一个 numpy array
,但是这个函数的输出总是一个 numpy array
。有时输出 array
会太大以至于出现 OOM Memory Errors
。
调用tensorflow.keras.model.predict()
时有没有输出生成器的建议??
制作生成器包装器:
def generator_predict(data_gen):
for data in data_gen:
yield tensorflow.keras.model.predict(data)
那么当你使用它时:
predict_gen = generator_predict(data_gen)
predicted_output = next(predict_gen)
...
类似的东西应该可以正常工作。
我知道,tensorflow.keras.model.predict()
可以接收一个 generator
或一个 numpy array
,但是这个函数的输出总是一个 numpy array
。有时输出 array
会太大以至于出现 OOM Memory Errors
。
调用tensorflow.keras.model.predict()
时有没有输出生成器的建议??
制作生成器包装器:
def generator_predict(data_gen):
for data in data_gen:
yield tensorflow.keras.model.predict(data)
那么当你使用它时:
predict_gen = generator_predict(data_gen)
predicted_output = next(predict_gen)
...
类似的东西应该可以正常工作。