如何从 TensorFlow Java 调用模型?
How to invoke model from TensorFlow Java?
以下 Python 代码将 ["hello", "world"]
传递给通用句子编码器(Universal Sentence Encoder),并返回表示其编码的浮点数组。
import tensorflow as tf
import tensorflow_hub as hub
# Wrap the TF Hub Universal Sentence Encoder (version 4) in a Keras model,
# run it on two example strings and print whatever the model returns.
sentences = ["hello", "world"]
module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
embeddings = model(sentences)
print("model: ", embeddings)
这段代码有效,但我现在想使用 Java API 做同样的事情。我已成功加载模块,但无法将输入传递到模型并提取输出。到目前为止,这是我得到的:
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
/**
 * The question's (non-working) attempt at invoking a SavedModel from TensorFlow Java.
 * <p>
 * It loads the bundle found in the {@code universal-sentence-encoder} directory with the
 * {@code serve} tag, opens a new {@link Session} over the bundle's graph and tries to run the
 * model on the strings "hello" and "world".
 * <p>
 * NOTE(review): {@code Path} is used below but only {@code java.nio.file.Paths} appears in the
 * import list shown with this snippet - confirm the import exists.
 */
public final class NaiveBayesClassifier
{
/**
 * Entry point: creates an instance and calls {@link #run()}.
 *
 * @param args ignored
 */
public static void main(String[] args)
{
new NaiveBayesClassifier().run();
}
/**
 * Loads a SavedModel bundle.
 *
 * @param source the directory to load from (converted to a normalized absolute path)
 * @param tags   the tags passed to {@code SavedModelBundle.load}
 * @return the loaded bundle
 * @throws IOException declared by the signature; nothing in this body throws it explicitly
 */
protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
{
return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
}
/**
 * Loads the model, feeds two UTF-8 encoded strings into it and attempts to run it.
 * An {@link IOException} is caught and its stack trace printed.
 */
public void run()
{
try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
{
Graph graph = module.graph();
// A brand-new Session is created over the bundle's graph, configured with GPU memory
// "allow growth" and soft device placement.
// NOTE(review): the resolution later in this thread reports that module.session() must be
// used instead, because the loader is what initializes the graph variables.
try (Session session = new Session(graph, ConfigProto.newBuilder().
setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
setAllowSoftPlacement(true).
build().toByteArray()))
{
// String tensor with two elements, each the UTF-8 bytes of one input sentence.
// NOTE(review): this tensor is not in a try-with-resources and is never closed here.
Tensor<String> input = Tensors.create(new byte[][]
{
"hello".getBytes(StandardCharsets.UTF_8),
"world".getBytes(StandardCharsets.UTF_8)
});
// "???" is a placeholder: the output node name was unknown when the question was asked.
// NOTE(review): the resolution later in this thread uses fetch() rather than addTarget()
// to retrieve output tensors.
List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
addTarget("???").run();
}
}
catch (IOException e)
{
e.printStackTrace();
}
}
}
我扫描了模型以寻找可能的输入/输出节点。我相信输入节点是 "serving_default_inputs",但我不知道输出节点。更重要的是,通过 Keras 在 Python 中调用这段代码时,我不必指定任何这些值,那么有没有办法用 Java API 做到同样的事?
更新:多亏了他人的帮助,我现在可以确认输入节点是 serving_default_input,输出节点是 StatefulPartitionedCall_1。但是当我把这些名称代入上述代码时,得到了以下错误:
2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
[[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
[[StatefulPartitionedCall_1/StatefulPartitionedCall]]
at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
at libtensorflow@1.15.0/org.tensorflow.Session.access0(Session.java:48)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)
也就是说,我仍然无法调用模型。我错过了什么?
您可以使用 Deep Java Library
加载 TF 模型
System.setProperty("ai.djl.repository.zoo.location", "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/1.tar.gz?artifact_id=encoder");
// build() yields the finished Criteria (not the Builder), and the variable must be named
// `criteria` because that is what ModelZoo.loadModel(...) is given below.
Criteria<NDList, NDList> criteria =
    Criteria.builder()
        .setTypes(NDList.class, NDList.class) // raw NDList in, raw NDList out
        .optArtifactId("ai.djl.localmodelzoo:encoder")
        .build();
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
详情见https://github.com/awslabs/djl/blob/master/docs/load_model.md#load-model-from-a-url
获取名字的方式有两种:
1) 使用 Java:
您可以从存储在已保存模型包中的 org.tensorflow.proto.framework.MetaGraphDef
中读取输入和输出名称。
下面是一个关于如何提取信息的例子:
2) 使用 python:
在 Python 中用 TensorFlow 加载已保存的模型并打印名称:
# Load the SavedModel from disk and list the names of the signatures it exposes.
loaded = tf.saved_model.load("path/to/model/")
signature_names = list(loaded.signatures.keys())
print(signature_names)
# Pick the "serving_default" signature and print its structured outputs.
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)
我建议看一下Deep Java Library,它会自动处理输入、输出名称。
它支持 TensorFlow 2.1.0 并允许您加载 Keras 模型以及 TF Hub 保存的模型。查看文档 here and here
如果您在加载模型时遇到问题,请随时打开 issue。
我后来想通了。
- 我需要使用
SavedModelBundle.session()
而不是自己构建 Session 实例。这是因为加载程序会初始化图中的变量。
- 我没有将
ConfigProto
传递给 Session
构造函数,而是将其传递给 SavedModelBundle
加载程序。
- 我需要使用
fetch()
而不是 addTarget()
来检索输出张量。
这是工作代码:
/**
 * Invokes a TensorFlow SavedModel (the Universal Sentence Encoder) from Java.
 * <p>
 * The bundle's own session is used to run the model, the session configuration is supplied to
 * the bundle loader, and the names of the tensors to feed/fetch are read from the serving
 * signature stored in the bundle's {@code MetaGraphDef}.
 */
public final class NaiveBayesClassifier
{
    /**
     * Entry point: creates an instance and calls {@link #run()}.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    /**
     * Loads the model from the {@code universal-sentence-encoder} directory, feeds it the
     * strings "hello" and "world" and prints every fetched tensor as a 2-D float array.
     * An {@link IOException} raised while loading or parsing is caught and its stack trace
     * printed.
     *
     * @throws IllegalStateException if the serving signature declares no inputs or outputs, or
     *                               if a fetched tensor is not of rank 2
     */
    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve");
             Tensor<String> input = Tensors.create(new byte[][]
                 {
                     "hello".getBytes(StandardCharsets.UTF_8),
                     "world".getBytes(StandardCharsets.UTF_8)
                 }))
        {
            MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
            String firstInput = firstName(getInputToShape(metadata), "input");
            String firstOutput = firstName(getOutputToShape(metadata), "output");
            System.out.println("input: " + firstInput);
            System.out.println("output: " + firstOutput);
            System.out.println();

            // fetch() (not addTarget()) is what makes run() return the output tensor.
            List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
                fetch(firstOutput).run();
            try
            {
                for (Tensor<?> tensor : result)
                    System.out.println(Arrays.deepToString(toMatrix(tensor)));
            }
            finally
            {
                // Tensors returned by run() are owned by the caller and must be closed.
                for (Tensor<?> tensor : result)
                    tensor.close();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Returns the first key of a name-to-shape map.
     *
     * @param nameToShape the map produced by {@link #getInputToShape} or {@link #getOutputToShape}
     * @param kind        "input" or "output"; only used in the error message
     * @return the first tensor name in the map's iteration order
     * @throws IllegalStateException if the map is empty
     */
    private static String firstName(Map<String, Shape> nameToShape, String kind)
    {
        if (nameToShape.isEmpty())
            throw new IllegalStateException("The serving signature declares no " + kind + " tensors");
        return nameToShape.keySet().iterator().next();
    }

    /**
     * Copies a rank-2 tensor into a {@code float[rows][columns]} array sized from the tensor's
     * actual shape.
     *
     * @param tensor the tensor to copy
     * @return the tensor contents
     * @throws IllegalStateException if the tensor is not of rank 2
     */
    private static float[][] toMatrix(Tensor<?> tensor)
    {
        long[] shape = tensor.shape();
        if (shape.length != 2)
            throw new IllegalStateException("Expected a rank-2 tensor but got: " + tensor);
        // Size the destination from shape(): rows = shape[0], columns = shape[1]. Deriving the
        // sizes from numDimensions() only works when the batch happens to contain 2 rows.
        float[][] matrix = new float[Math.toIntExact(shape[0])][Math.toIntExact(shape[1])];
        tensor.copyTo(matrix);
        return matrix;
    }

    /**
     * Loads a SavedModel bundle from a directory.
     * <p>
     * The session configuration (GPU memory "allow growth" and soft device placement) is passed
     * to the loader, so it applies to the session owned by the returned bundle.
     *
     * @param source the directory to load from
     * @param tags   the model variant(s) to load
     * @return the loaded bundle
     * @throws NullPointerException if any of the arguments are null
     * @throws IOException          if TensorFlow fails to load the model; the
     *                              {@code TensorFlowException} is attached as the cause
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        try
        {
            return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
                withTags(tags).
                withConfigProto(ConfigProto.newBuilder().
                    setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                    setAllowSoftPlacement(true).
                    build().toByteArray()).
                load();
        }
        catch (TensorFlowException e)
        {
            throw new IOException(e);
        }
    }

    /**
     * @param metadata the graph metadata
     * @return the first signature in the map's iteration order, or null if there are none
     */
    private SignatureDef getFirstSignature(MetaGraphDef metadata)
    {
        Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
        if (nameToSignature.isEmpty())
            return null;
        return nameToSignature.values().iterator().next();
    }

    /**
     * @param metadata the graph metadata
     * @return the {@code serving_default} signature if present, otherwise the first signature
     * @throws IllegalArgumentException if the metadata contains no signatures at all
     */
    private SignatureDef getServingSignature(MetaGraphDef metadata)
    {
        SignatureDef signature = metadata.getSignatureDefOrDefault("serving_default",
            getFirstSignature(metadata));
        if (signature == null)
            throw new IllegalArgumentException("The model metadata does not contain any signatures");
        return signature;
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an output name to its shape
     * @throws IllegalArgumentException if the metadata contains no signatures
     */
    protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
    {
        return toNameToShape(getServingSignature(metadata).getOutputsMap());
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an input name to its shape
     * @throws IllegalArgumentException if the metadata contains no signatures
     */
    protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
    {
        return toNameToShape(getServingSignature(metadata).getInputsMap());
    }

    /**
     * Converts the tensor descriptions of a signature into a map keyed by tensor name.
     *
     * @param keyToInfo the signature's inputs or outputs
     * @return a map from {@code TensorInfo.getName()} to the tensor's shape
     */
    private static Map<String, Shape> toNameToShape(Map<String, TensorInfo> keyToInfo)
    {
        Map<String, Shape> result = new HashMap<>();
        for (TensorInfo info : keyToInfo.values())
            result.put(info.getName(), toShape(info.getTensorShape()));
        return result;
    }

    /**
     * Converts a shape proto into a {@code Shape}, including shapes that have no dimensions.
     *
     * @param shapeProto the shape as stored in the signature
     * @return the equivalent shape
     */
    private static Shape toShape(TensorShapeProto shapeProto)
    {
        if (shapeProto.getUnknownRank())
            return Shape.unknown();
        List<Dim> dimensions = shapeProto.getDimList();
        // Shape.make() needs at least one dimension, so a dimension-less shape is a scalar.
        if (dimensions.isEmpty())
            return Shape.scalar();
        long firstDimension = dimensions.get(0).getSize();
        long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
        return Shape.make(firstDimension, remainingDimensions);
    }
}
我需要做同样的事情,但关于 DJL 的用法似乎仍有很多内容没有说明。例如,在这之后该做什么?:
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
终于在DJL源码中找到了例子。关键要点是根本不要对 input/output 使用 NDList:
// Describe the model to load: a text-embedding model that takes String[] and returns float[][],
// so no NDList handling is needed by the caller.
// NOTE(review): `modelUrl` is declared outside this fragment.
Criteria<String[], float[][]> criteria =
Criteria.builder()
.optApplication(Application.NLP.TEXT_EMBEDDING)
.setTypes(String[].class, float[][].class)
.optModelUrls(modelUrl)
.build();
// Both the model and the predictor are closed automatically when the try block exits.
try (ZooModel<String[], float[][]> model = ModelZoo.loadModel(criteria);
Predictor<String[], float[][]> predictor = model.newPredictor()) {
// NOTE(review): `inputs` is declared outside this fragment; it is converted to a String[] here.
return predictor.predict(inputs.toArray(new String[0]));
}
以下 Python 代码将 ["hello", "world"]
传递给通用句子编码器(Universal Sentence Encoder),并返回表示其编码的浮点数组。
import tensorflow as tf
import tensorflow_hub as hub
# Wrap the TF Hub Universal Sentence Encoder (version 4) in a Keras model,
# run it on two example strings and print whatever the model returns.
sentences = ["hello", "world"]
module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
embeddings = model(sentences)
print("model: ", embeddings)
这段代码有效,但我现在想使用 Java API 做同样的事情。我已成功加载模块,但无法将输入传递到模型并提取输出。到目前为止,这是我得到的:
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
/**
 * The question's (non-working) attempt at invoking a SavedModel from TensorFlow Java.
 * <p>
 * It loads the bundle found in the {@code universal-sentence-encoder} directory with the
 * {@code serve} tag, opens a new {@link Session} over the bundle's graph and tries to run the
 * model on the strings "hello" and "world".
 * <p>
 * NOTE(review): {@code Path} is used below but only {@code java.nio.file.Paths} appears in the
 * import list shown with this snippet - confirm the import exists.
 */
public final class NaiveBayesClassifier
{
/**
 * Entry point: creates an instance and calls {@link #run()}.
 *
 * @param args ignored
 */
public static void main(String[] args)
{
new NaiveBayesClassifier().run();
}
/**
 * Loads a SavedModel bundle.
 *
 * @param source the directory to load from (converted to a normalized absolute path)
 * @param tags   the tags passed to {@code SavedModelBundle.load}
 * @return the loaded bundle
 * @throws IOException declared by the signature; nothing in this body throws it explicitly
 */
protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
{
return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
}
/**
 * Loads the model, feeds two UTF-8 encoded strings into it and attempts to run it.
 * An {@link IOException} is caught and its stack trace printed.
 */
public void run()
{
try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
{
Graph graph = module.graph();
// A brand-new Session is created over the bundle's graph, configured with GPU memory
// "allow growth" and soft device placement.
// NOTE(review): the resolution later in this thread reports that module.session() must be
// used instead, because the loader is what initializes the graph variables.
try (Session session = new Session(graph, ConfigProto.newBuilder().
setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
setAllowSoftPlacement(true).
build().toByteArray()))
{
// String tensor with two elements, each the UTF-8 bytes of one input sentence.
// NOTE(review): this tensor is not in a try-with-resources and is never closed here.
Tensor<String> input = Tensors.create(new byte[][]
{
"hello".getBytes(StandardCharsets.UTF_8),
"world".getBytes(StandardCharsets.UTF_8)
});
// "???" is a placeholder: the output node name was unknown when the question was asked.
// NOTE(review): the resolution later in this thread uses fetch() rather than addTarget()
// to retrieve output tensors.
List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
addTarget("???").run();
}
}
catch (IOException e)
{
e.printStackTrace();
}
}
}
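我扫描了模型以寻找可能的输入/输出节点。我相信输入节点是 "serving_default_inputs",但我不知道输出节点。更重要的是,通过 Keras 在 Python 中调用这段代码时,我不必指定任何这些值,那么有没有办法用 Java API 做到同样的事?
更新:多亏了他人的帮助,我现在可以确认输入节点是 serving_default_input,输出节点是 StatefulPartitionedCall_1。但是当我把这些名称代入上述代码时,得到了以下错误: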
2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
[[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
[[StatefulPartitionedCall_1/StatefulPartitionedCall]]
at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
at libtensorflow@1.15.0/org.tensorflow.Session.access0(Session.java:48)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)
也就是说,我仍然无法调用模型。我错过了什么?
您可以使用 Deep Java Library
加载 TF 模型:
System.setProperty("ai.djl.repository.zoo.location", "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/1.tar.gz?artifact_id=encoder");
// build() yields the finished Criteria (not the Builder), and the variable must be named
// `criteria` because that is what ModelZoo.loadModel(...) is given below.
Criteria<NDList, NDList> criteria =
    Criteria.builder()
        .setTypes(NDList.class, NDList.class) // raw NDList in, raw NDList out
        .optArtifactId("ai.djl.localmodelzoo:encoder")
        .build();
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
详情见https://github.com/awslabs/djl/blob/master/docs/load_model.md#load-model-from-a-url
获取名字的方式有两种:
1) 使用 Java:
您可以从存储在已保存模型包中的 org.tensorflow.proto.framework.MetaGraphDef
中读取输入和输出名称。
下面是一个关于如何提取信息的例子:
2) 使用 python:
在 Python 中用 TensorFlow 加载已保存的模型并打印名称:
# Load the SavedModel from disk and list the names of the signatures it exposes.
loaded = tf.saved_model.load("path/to/model/")
signature_names = list(loaded.signatures.keys())
print(signature_names)
# Pick the "serving_default" signature and print its structured outputs.
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)
我建议看一下Deep Java Library,它会自动处理输入、输出名称。 它支持 TensorFlow 2.1.0 并允许您加载 Keras 模型以及 TF Hub 保存的模型。查看文档 here and here
如果您在加载模型时遇到问题,请随时打开 issue。
我后来想通了。
- 我需要使用 SavedModelBundle.session(),而不是自己构建 Session 实例。这是因为加载程序会初始化图中的变量。
- 我没有将 ConfigProto 传递给 Session 构造函数,而是将其传递给 SavedModelBundle 加载程序。
- 我需要使用 fetch() 而不是 addTarget() 来检索输出张量。
这是工作代码:
/**
 * Invokes a TensorFlow SavedModel (the Universal Sentence Encoder) from Java.
 * <p>
 * The bundle's own session is used to run the model, the session configuration is supplied to
 * the bundle loader, and the names of the tensors to feed/fetch are read from the serving
 * signature stored in the bundle's {@code MetaGraphDef}.
 */
public final class NaiveBayesClassifier
{
    /**
     * Entry point: creates an instance and calls {@link #run()}.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    /**
     * Loads the model from the {@code universal-sentence-encoder} directory, feeds it the
     * strings "hello" and "world" and prints every fetched tensor as a 2-D float array.
     * An {@link IOException} raised while loading or parsing is caught and its stack trace
     * printed.
     *
     * @throws IllegalStateException if the serving signature declares no inputs or outputs, or
     *                               if a fetched tensor is not of rank 2
     */
    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve");
             Tensor<String> input = Tensors.create(new byte[][]
                 {
                     "hello".getBytes(StandardCharsets.UTF_8),
                     "world".getBytes(StandardCharsets.UTF_8)
                 }))
        {
            MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
            String firstInput = firstName(getInputToShape(metadata), "input");
            String firstOutput = firstName(getOutputToShape(metadata), "output");
            System.out.println("input: " + firstInput);
            System.out.println("output: " + firstOutput);
            System.out.println();

            // fetch() (not addTarget()) is what makes run() return the output tensor.
            List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
                fetch(firstOutput).run();
            try
            {
                for (Tensor<?> tensor : result)
                    System.out.println(Arrays.deepToString(toMatrix(tensor)));
            }
            finally
            {
                // Tensors returned by run() are owned by the caller and must be closed.
                for (Tensor<?> tensor : result)
                    tensor.close();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Returns the first key of a name-to-shape map.
     *
     * @param nameToShape the map produced by {@link #getInputToShape} or {@link #getOutputToShape}
     * @param kind        "input" or "output"; only used in the error message
     * @return the first tensor name in the map's iteration order
     * @throws IllegalStateException if the map is empty
     */
    private static String firstName(Map<String, Shape> nameToShape, String kind)
    {
        if (nameToShape.isEmpty())
            throw new IllegalStateException("The serving signature declares no " + kind + " tensors");
        return nameToShape.keySet().iterator().next();
    }

    /**
     * Copies a rank-2 tensor into a {@code float[rows][columns]} array sized from the tensor's
     * actual shape.
     *
     * @param tensor the tensor to copy
     * @return the tensor contents
     * @throws IllegalStateException if the tensor is not of rank 2
     */
    private static float[][] toMatrix(Tensor<?> tensor)
    {
        long[] shape = tensor.shape();
        if (shape.length != 2)
            throw new IllegalStateException("Expected a rank-2 tensor but got: " + tensor);
        // Size the destination from shape(): rows = shape[0], columns = shape[1]. Deriving the
        // sizes from numDimensions() only works when the batch happens to contain 2 rows.
        float[][] matrix = new float[Math.toIntExact(shape[0])][Math.toIntExact(shape[1])];
        tensor.copyTo(matrix);
        return matrix;
    }

    /**
     * Loads a SavedModel bundle from a directory.
     * <p>
     * The session configuration (GPU memory "allow growth" and soft device placement) is passed
     * to the loader, so it applies to the session owned by the returned bundle.
     *
     * @param source the directory to load from
     * @param tags   the model variant(s) to load
     * @return the loaded bundle
     * @throws NullPointerException if any of the arguments are null
     * @throws IOException          if TensorFlow fails to load the model; the
     *                              {@code TensorFlowException} is attached as the cause
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        try
        {
            return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
                withTags(tags).
                withConfigProto(ConfigProto.newBuilder().
                    setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                    setAllowSoftPlacement(true).
                    build().toByteArray()).
                load();
        }
        catch (TensorFlowException e)
        {
            throw new IOException(e);
        }
    }

    /**
     * @param metadata the graph metadata
     * @return the first signature in the map's iteration order, or null if there are none
     */
    private SignatureDef getFirstSignature(MetaGraphDef metadata)
    {
        Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
        if (nameToSignature.isEmpty())
            return null;
        return nameToSignature.values().iterator().next();
    }

    /**
     * @param metadata the graph metadata
     * @return the {@code serving_default} signature if present, otherwise the first signature
     * @throws IllegalArgumentException if the metadata contains no signatures at all
     */
    private SignatureDef getServingSignature(MetaGraphDef metadata)
    {
        SignatureDef signature = metadata.getSignatureDefOrDefault("serving_default",
            getFirstSignature(metadata));
        if (signature == null)
            throw new IllegalArgumentException("The model metadata does not contain any signatures");
        return signature;
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an output name to its shape
     * @throws IllegalArgumentException if the metadata contains no signatures
     */
    protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
    {
        return toNameToShape(getServingSignature(metadata).getOutputsMap());
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an input name to its shape
     * @throws IllegalArgumentException if the metadata contains no signatures
     */
    protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
    {
        return toNameToShape(getServingSignature(metadata).getInputsMap());
    }

    /**
     * Converts the tensor descriptions of a signature into a map keyed by tensor name.
     *
     * @param keyToInfo the signature's inputs or outputs
     * @return a map from {@code TensorInfo.getName()} to the tensor's shape
     */
    private static Map<String, Shape> toNameToShape(Map<String, TensorInfo> keyToInfo)
    {
        Map<String, Shape> result = new HashMap<>();
        for (TensorInfo info : keyToInfo.values())
            result.put(info.getName(), toShape(info.getTensorShape()));
        return result;
    }

    /**
     * Converts a shape proto into a {@code Shape}, including shapes that have no dimensions.
     *
     * @param shapeProto the shape as stored in the signature
     * @return the equivalent shape
     */
    private static Shape toShape(TensorShapeProto shapeProto)
    {
        if (shapeProto.getUnknownRank())
            return Shape.unknown();
        List<Dim> dimensions = shapeProto.getDimList();
        // Shape.make() needs at least one dimension, so a dimension-less shape is a scalar.
        if (dimensions.isEmpty())
            return Shape.scalar();
        long firstDimension = dimensions.get(0).getSize();
        long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
        return Shape.make(firstDimension, remainingDimensions);
    }
}
我需要做同样的事情,但关于 DJL 的用法似乎仍有很多内容没有说明。例如,在这之后该做什么?:
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
终于在DJL源码中找到了例子。关键要点是根本不要对 input/output 使用 NDList:
// Describe the model to load: a text-embedding model that takes String[] and returns float[][],
// so no NDList handling is needed by the caller.
// NOTE(review): `modelUrl` is declared outside this fragment.
Criteria<String[], float[][]> criteria =
Criteria.builder()
.optApplication(Application.NLP.TEXT_EMBEDDING)
.setTypes(String[].class, float[][].class)
.optModelUrls(modelUrl)
.build();
// Both the model and the predictor are closed automatically when the try block exits.
try (ZooModel<String[], float[][]> model = ModelZoo.loadModel(criteria);
Predictor<String[], float[][]> predictor = model.newPredictor()) {
// NOTE(review): `inputs` is declared outside this fragment; it is converted to a String[] here.
return predictor.predict(inputs.toArray(new String[0]));
}