从另一个 class 调用时,Tensorflow 给出“<tensor> 不是该图的元素”错误

Tensorflow gives "<tensor> is not an element of this graph" error when calling from another class

首先,我已经访问过这里和其他网站的所有类似主题,但是 none 对我的情况有效。

假设我有一个 class 处理加载模型和预测:

class MyModel():
  def __init__(self):
    pass

  def load_model(self, model_path):
    self.model = tf.keras.models.load_model(model_path)

  def predict(self, img):
    return self.model.predict(img)

现在,我在另一个文件中有另一个 class,它调用 MyModel:

from mymodel import MyModel
class MyDetector():
    def __init__(self):
        self.detector = MyModel()
        self.detector.load_model('mymodel.h5')

    def detect(self, img: numpy.ndarray):        
        return self.detector.predict(img)

然而,这会抛出一个错误 <tensor> is not an element of this graph。我已经尝试了所有可用的 tf.Graph.as_default() 相关答案,但没有任何改变。最常见的建议是修改模型加载和预测部分如下:

def load_model(self, model_path):
    global model
    model = tf.keras.models.load_model(model_path)
    global graph
    graph = tf.get_default_graph() 

def predict(self, img):
    with graph.as_default():
      preds = model.predict(img)
    return preds

这仍然无济于事,因为所有其他建议也可用于:https://github.com/keras-team/keras/issues/6462

我认为我的案例与那些已经解决过类似案例的案例不同,因为我尝试从一个完全不同的 class 文件中调用模型 class。我的 Tensorflow 版本是 2.6.0。有人可以更好地了解如何解决它吗?

更新

实际情况是我正在使用 gRPC 与远程服务器通信以进行模型推理。为此,我使用了一个非常简单的基于 gRPC 的客户端-服务器通信。我的客户端代码定义如下(client.py):

import cv2
import grpc 
import pybase64

import protos.mydetector_pb2 as mydetector_pb2
import protos.mydetector_pb2_grpc as mydetector_pb2_grpc 

# open a gRPC channel
channel = grpc.insecure_channel('[::]:50051')
stub = mydetector_pb2_grpc.MyDetectionServiceStub(channel)

img = cv2.imread('test.jpg')
retval, buffer = cv2.imencode('.jpg', img)
b64img = pybase64.b64encode(buffer)

print('\nSending single request to port 50051...')
request = mydetector_pb2.MyDetectionRequest(image=b64img)

response = stub.detect(request)

那么在接收服务器端,主服务器实现如下(server.py):

import grpc
from concurrent import futures
import protos.mydetector_pb2_grpc as reid_grpc
import MyDetectionService

MAX_MESSAGE_IN_MB = 10

options = [
    ('grpc.max_send_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024),
    ('grpc.max_receive_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024)
]

server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=options)
reid_grpc.add_MyDetectionServiceServicer_to_server(MyDetectionService(), server)

print('Starting server. Listening on port 50051.')
server.add_insecure_port('[::]:50051')
server.start()

MyDetectionServiceclass实现如下:

import protos.mydetector_pb2 as mydetector
import protos.mydetector_pb2_grpc as mydetector_grpc
from mydetector.service.detector import MyDetector
from mydetector.utils import img_converter

import cv2
import numpy as np

class MyDetectionService(mydetector_grpc.MyDetectionServiceServicer):
def __init__(self):
    self.detector = MyDetector()

def detect(self, request, context):
    print('detecting on received image...')
    encoded_img = request.image
    img = img_converter(encoded_img)
    img = cv2.resize(img, (240, 240))
    img2 = np.expand_dims(img, axis=0)
    result = self.detector.detect(img2)
    return mydetector.MyDetectionResponse(ans=result)

其中,MyDetector class 实现如上所示。

我发现如果我不使用基于 gRPC 的服务器-客户端通信,而是从 class 以外的任何其他常规方式调用 MyDetector,一切都会顺利进行。但是,当我通过 gRPC 从客户端发送图像时,它成功地在 MyDetector class 中加载了模型(我可以调用 model.summary() 来获取模型的完整描述),但是detect 函数失败。

重要提示:根据可用信息here,我相信每次发出 gRPC 请求时,它都会使用自己的 Tensorflow 会话创建新线程,这就是这里的主要问题。但是,即使按照该站点上描述的所有说明进行操作,我仍然无法使其正常工作。

我在 https://colab.research.google.com/drive/1OaH7ZoAsY_V1sMUmc1NumWmJPnNr_54F?usp=sharing 上复制了问题的意图。 Colab 成功使用 MyDetector 预测 MNIST 图像。

作为本练习的一部分,我看到这里发生了几件事:

  1. MyModel.model_path 未定义。尽管在 MyModel.load_model 中提供了 model_path 作为参数,但它未被使用。换句话说,我猜 load_model 部分或问题描述中有错字。

此外,这里有几点想法:

  1. tf.get_default_graph() 在 TensorFlow 2.6.0 中不起作用。 TF 2.6 具有类似的 tf.compat.v1.get_default_graph()。我会 运行 强烈推荐 运行ning tf.version 以确认执行代码确实使用 2.6.0.

  2. 如果可以,请将 MyDetector 添加到 MyModel 文件中。如果有效,那么您就知道问题与某些代码位于单独的文件中这一事实有关,这可能有助于排除故障。

基于以上所述,我建议在启用 Eager Execution 的情况下调试问题,看看是否可以让事情正常进行。

您的服务器加载模型的方式与用于接收客户端请求和响应的方式不同graph/session。将您的 MyModel class 修改为应该有效:

class MyModel():
  def __init__(self):
    pass

  def load_model(self, model_path):
    self.graph = tf.compat.v1.get_default_graph()

    with self.graph.as_default():
      self.model = tf.keras.models.load_model(model_path)

    self.sess = tf.compat.v1.keras.backend.get_session()

  def predict(self, img):
    with self.graph.as_default():
      try:
        preds= self.model.predict(img)
      except tf.errors.FailedPreconditionError:
        tf.compat.v1.keras.backend.set_session(self.sess)
        preds= self.model.predict(img)

    return preds