Python Beam 不能 pickle/dill 大型 Tensorflow 模型
Python Beam can't pickle/dill a large Tensorflow Model
我们正在尝试在线提供图像处理模型(在 Tensorflow 中),这样我们就不必出于速度目的而对 REST 服务或 Cloud-ML/ML-Engine 模型进行外部调用。
我们不想在每次推理时都加载模型,而是想测试是否可以为 beam.DoFn 对象的每个实例将模型加载到内存中,这样我们就可以减少模型的加载和服务时间。
例如
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
class InferenceFn(object):
def __init__(self, model_full_path,):
super(InferenceFn, self).__init__()
self.model_full_path = model_full_path
self.graph = None
self.create_graph()
def create_graph(self):
if not tf.gfile.FastGFile(self.model_full_path):
self.download_model_file()
with tf.Graph().as_default() as graph:
with tf.gfile.FastGFile(self.model_full_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
self.graph = graph
当它不是 beam.DoFn 而只是一个常规的 class 时,它可以 运行 在本地很好,但是当它转换为 DoFn 并且我尝试远程执行它时使用 Cloud Dataflow,作业失败,因为在 serialization/pickling 期间,我想相信它试图序列化整个模型
例如
Example of Error
有没有办法绕过这个问题或阻止 python/dataflow 尝试序列化模型?
是——将模型存储为 DoFn 上的一个字段需要对其进行序列化,以便将该代码发送给每个工作人员。您应该查看以下内容:
- 安排让每个工人都能使用模型文件。这在 Python dependencies document.
中针对数据流进行了描述
- 在您的 DoFn 中实现
start_bundle
方法并让它读取文件并将其存储在本地线程中。
这确保文件的内容不会在您的本地计算机上读取和腌制,而是让每个工作人员都可以使用该文件然后读入。
我们正在尝试在线提供图像处理模型(在 Tensorflow 中),这样我们就不必出于速度目的而对 REST 服务或 Cloud-ML/ML-Engine 模型进行外部调用。
我们不想在每次推理时都加载模型,而是想测试是否可以为 beam.DoFn 对象的每个实例将模型加载到内存中,这样我们就可以减少模型的加载和服务时间。
例如
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
class InferenceFn(object):
def __init__(self, model_full_path,):
super(InferenceFn, self).__init__()
self.model_full_path = model_full_path
self.graph = None
self.create_graph()
def create_graph(self):
if not tf.gfile.FastGFile(self.model_full_path):
self.download_model_file()
with tf.Graph().as_default() as graph:
with tf.gfile.FastGFile(self.model_full_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
self.graph = graph
当它不是 beam.DoFn 而只是一个常规的 class 时,它可以 运行 在本地很好,但是当它转换为 DoFn 并且我尝试远程执行它时使用 Cloud Dataflow,作业失败,因为在 serialization/pickling 期间,我想相信它试图序列化整个模型
例如 Example of Error
有没有办法绕过这个问题或阻止 python/dataflow 尝试序列化模型?
是——将模型存储为 DoFn 上的一个字段需要对其进行序列化,以便将该代码发送给每个工作人员。您应该查看以下内容:
- 安排让每个工人都能使用模型文件。这在 Python dependencies document. 中针对数据流进行了描述
- 在您的 DoFn 中实现
start_bundle
方法并让它读取文件并将其存储在本地线程中。
这确保文件的内容不会在您的本地计算机上读取和腌制,而是让每个工作人员都可以使用该文件然后读入。