How to invoke a model from TensorFlow Java?

The following Python code passes ["hello", "world"] into the Universal Sentence Encoder and returns an array of floats denoting their encoded representation.

import tensorflow as tf
import tensorflow_hub as hub

module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
print("model: ", model(["hello", "world"]))

This code works, but I now want to do the same thing using the Java API. I've successfully loaded the module, but I am unable to pass inputs into the model and extract the output. Here is what I've got so far:

import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            Graph graph = module.graph();
            try (Session session = new Session(graph, ConfigProto.newBuilder().
                setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                setAllowSoftPlacement(true).
                build().toByteArray()))
            {
                Tensor<String> input = Tensors.create(new byte[][]
                    {
                        "hello".getBytes(StandardCharsets.UTF_8),
                        "world".getBytes(StandardCharsets.UTF_8)
                    });
                List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
                    addTarget("???").run();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }
}

I scanned the model looking for possible input/output nodes. I believe the input node is "serving_default_inputs", but I cannot figure out the output node. More importantly, I don't have to specify any of these values when invoking the code in Python through Keras, so is there a way to do the same using the Java API?
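(For reference: one common way to inspect a SavedModel's signatures from the command line is TensorFlow's `saved_model_cli`, which ships with the Python package — this may not be the tool used above, but it surfaces the same input/output names:)

```shell
# Assumes the TensorFlow Python package is installed and the model
# directory is ./universal-sentence-encoder; prints the input and
# output tensor names of the serving_default signature.
saved_model_cli show --dir universal-sentence-encoder \
    --tag_set serve --signature_def serving_default
```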

Update: I can now confirm that the input node is serving_default_input and the output node is StatefulPartitionedCall_1, but when I plug these names into the above code I get:

2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
     [[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
     [[StatefulPartitionedCall_1/StatefulPartitionedCall]]
    at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
    at libtensorflow@1.15.0/org.tensorflow.Session.access0(Session.java:48)
    at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)

In other words, I still cannot invoke the model. What am I missing?

You can use the Deep Java Library.

Load the TF model:
System.setProperty("ai.djl.repository.zoo.location", "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/1.tar.gz?artifact_id=encoder");

Criteria<NDList, NDList> criteria =
        Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optArtifactId("ai.djl.localmodelzoo:encoder")
                .build();
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);

See https://github.com/awslabs/djl/blob/master/docs/load_model.md#load-model-from-a-url for details.

There are two ways to get the names:

1) Using Java:

You can read the input and output names from the org.tensorflow.proto.framework.MetaGraphDef stored in the saved model bundle.

Here is an example of how to extract that information:

https://github.com/awslabs/djl/blob/master/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java#L149

2) Using Python:

Load the saved model in Python TensorFlow and print the names:

import tensorflow as tf

loaded = tf.saved_model.load("path/to/model/")
print(list(loaded.signatures.keys()))
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)

I suggest taking a look at the Deep Java Library, which handles the input and output names automatically. It supports TensorFlow 2.1.0 and lets you load Keras models as well as TF Hub saved models. Take a look at the documentation here and here.

Feel free to open an issue if you run into problems loading your model.

I figured it out afterwards:

  • I needed to use SavedModelBundle.session() instead of constructing my own instance, because the loader initializes the graph variables.
  • Instead of passing a ConfigProto to the Session constructor, I pass it to the SavedModelBundle loader instead.
  • I needed to use fetch() instead of addTarget() to retrieve the output tensor.

Here is the working code:

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;
import org.tensorflow.framework.TensorShapeProto.Dim;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            try (Tensor<String> input = Tensors.create(new byte[][]
                {
                    "hello".getBytes(StandardCharsets.UTF_8),
                    "world".getBytes(StandardCharsets.UTF_8)
                }))
            {
                MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
                Map<String, Shape> nameToInput = getInputToShape(metadata);
                String firstInput = nameToInput.keySet().iterator().next();

                Map<String, Shape> nameToOutput = getOutputToShape(metadata);
                String firstOutput = nameToOutput.keySet().iterator().next();

                System.out.println("input: " + firstInput);
                System.out.println("output: " + firstOutput);
                System.out.println();

                List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
                    fetch(firstOutput).run();
                for (Tensor<?> tensor : result)
                {
                    // Size the array from the tensor's actual shape (e.g. [2, 512])
                    long[] shape = tensor.shape();
                    float[][] array = new float[(int) shape[0]][(int) shape[1]];
                    tensor.copyTo(array);
                    System.out.println(Arrays.deepToString(array));
                }
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Loads a graph from a file.
     *
     * @param source the directory containing the model to load from
     * @param tags   the model variant(s) to load
     * @return the graph
     * @throws NullPointerException if any of the arguments are null
     * @throws IOException          if an error occurs while reading the file
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        try
        {
            return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
                withTags(tags).
                withConfigProto(ConfigProto.newBuilder().
                    setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                    setAllowSoftPlacement(true).
                    build().toByteArray()).
                load();
        }
        catch (TensorFlowException e)
        {
            throw new IOException(e);
        }
    }

    /**
     * @param metadata the graph metadata
     * @return the first signature, or null
     */
    private SignatureDef getFirstSignature(MetaGraphDef metadata)
    {
        Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
        if (nameToSignature.isEmpty())
            return null;
        return nameToSignature.get(nameToSignature.keySet().iterator().next());
    }

    /**
     * @param metadata the graph metadata
     * @return the output signature
     */
    private SignatureDef getServingSignature(MetaGraphDef metadata)
    {
        return metadata.getSignatureDefOrDefault("serving_default", getFirstSignature(metadata));
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an output name to its shape
     */
    protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
    {
        Map<String, Shape> result = new HashMap<>();
        SignatureDef servingDefault = getServingSignature(metadata);
        for (Map.Entry<String, TensorInfo> entry : servingDefault.getOutputsMap().entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            long firstDimension = dimensions.get(0).getSize();
            long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
            Shape shape = Shape.make(firstDimension, remainingDimensions);
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an input name to its shape
     */
    protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
    {
        Map<String, Shape> result = new HashMap<>();
        SignatureDef servingDefault = getServingSignature(metadata);
        for (Map.Entry<String, TensorInfo> entry : servingDefault.getInputsMap().entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            long firstDimension = dimensions.get(0).getSize();
            long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
            Shape shape = Shape.make(firstDimension, remainingDimensions);
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }
}

I need to do the same thing, but I still seem to be missing a lot regarding DJL usage. For example, what do I do after this?:

ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);

I finally found an example in the DJL source code. The key takeaway is not to use NDList for the input/output at all:

Criteria<String[], float[][]> criteria =
        Criteria.builder()
                .optApplication(Application.NLP.TEXT_EMBEDDING)
                .setTypes(String[].class, float[][].class)
                .optModelUrls(modelUrl)
                .build();
try (ZooModel<String[], float[][]> model = ModelZoo.loadModel(criteria);
        Predictor<String[], float[][]> predictor = model.newPredictor()) {
    return predictor.predict(inputs.toArray(new String[0]));
}

See https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java for a complete example.