在 Tensorflow 中从文件导入图形

Import Graph From File in Tensorflow

我正在尝试从位于此处的 DeepFovea 项目导入 Facebook 发布的 protobuf:https://raw.githubusercontent.com/facebookresearch/DeepFovea/master/input_graph.pb

这是我的代码:

import tensorflow.compat.v1 as tf
from tensorflow.python.platform import gfile

tf.GraphDef.FromString(tf.gfile.Open("./input_graph.pb",'rb').read())

我收到这个错误:

google.protobuf.message.DecodeError: Error parsing message

我应该以不同的方式加载这个 protobuf 吗?

经过多次谷歌搜索后,发现您需要像这样解析它:

from google.protobuf import text_format

with tf.gfile.GFile(graph_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_str = f.read()
    text_format.Merge(graph_str, graph_def)

来自此处的代码示例:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/test_util_test.py#L84