"No Operation named [input] in the Graph" in Java

Following this Colab exercise from Google's Machine Learning Crash Course, I built a model for the MNIST database in Python. The code is shown below:

import pandas as pd
import tensorflow as tf


def create_model(my_learning_rate):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=(28, 28), name='input'))
    model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
    model.add(tf.keras.layers.Dense(units=256, activation='relu'))
    model.add(tf.keras.layers.Dense(units=128, activation='relu'))
    model.add(tf.keras.layers.Dropout(rate=0.2))
    model.add(tf.keras.layers.Dense(units=10, activation='softmax', name='output'))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=my_learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def train_model(model, train_features, train_label, epochs,
                batch_size=None, validation_split=0.1):
    history = model.fit(x=train_features, y=train_label, batch_size=batch_size,
                        epochs=epochs, shuffle=True,
                        validation_split=validation_split)
    epochs = history.epoch
    hist = pd.DataFrame(history.history)
    return epochs, hist


if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train_normalized = x_train / 255.0
    x_test_normalized = x_test / 255.0

    learning_rate = 0.003
    epochs = 50
    batch_size = 4000
    validation_split = 0.2

    my_model = create_model(learning_rate)
    epochs, hist = train_model(my_model, x_train_normalized, y_train,
                               epochs, batch_size, validation_split)

    my_model.save('my_model')

The model is saved as-is to the "my_model" folder. Now I load it again in my Java program:

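import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class HelloTensorFlow {
    public static void main(final String[] args) {
        final String filePath = Paths.get("my_model").toAbsolutePath().toString();
        try (final SavedModelBundle b = SavedModelBundle.load(filePath, "serve");
             // The model's Input layer has shape (28, 28), so one image is [1, 28, 28]
             final Tensor<Float> x = Tensor.create(new float[1][28][28], Float.class)) {
            final Session sess = b.session();

            // Feeds/fetches by the Keras layer names; this throws
            // "No Operation named [input] in the Graph"
            final List<Tensor<?>> run = sess.runner()
                    .feed("input", x)
                    .fetch("output")
                    .run();

            try (Tensor<?> result = run.get(0)) {
                // The softmax layer yields 10 probabilities per image: shape [1, 10]
                final float[][] y = result.copyTo(new float[1][10]);
                System.out.println(Arrays.toString(y[0]));
            }
        }
    }
}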
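Once the operation name is resolved, the fed tensor also has to match what the model expects: the Keras model declares `Input(shape=(28, 28))`, was trained on pixels divided by 255.0, and ends in a 10-unit softmax. The stdlib-only sketch below shows one way to build a normalized `float[1][28][28]` batch from raw 0-255 pixels and to read the predicted digit from a `float[1][10]` result. The class `MnistIo` and its methods `toInputBatch` and `argmax` are hypothetical helpers, not part of TensorFlow; the arrays they handle are what you would pass to `Tensor.create(...)` and `copyTo(...)`.

```java
import java.util.Arrays;

public class MnistIo {
    static final int SIDE = 28;

    /** Hypothetical helper: turns 784 row-major pixels (0-255) into a normalized [1][28][28] batch. */
    static float[][][] toInputBatch(int[] pixels) {
        if (pixels.length != SIDE * SIDE) {
            throw new IllegalArgumentException(
                "expected " + SIDE * SIDE + " pixels, got " + pixels.length);
        }
        float[][][] batch = new float[1][SIDE][SIDE];
        for (int row = 0; row < SIDE; row++) {
            for (int col = 0; col < SIDE; col++) {
                // Same scaling as x_train / 255.0 in the Python training script
                batch[0][row][col] = pixels[row * SIDE + col] / 255.0f;
            }
        }
        return batch;
    }

    /** Hypothetical helper: index of the largest value in the single row of a [1][10] softmax output. */
    static int argmax(float[][] probabilities) {
        float[] row = probabilities[0];
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] pixels = new int[SIDE * SIDE];
        pixels[0] = 255;                          // one white pixel in the top-left corner
        float[][][] batch = toInputBatch(pixels);
        System.out.println(batch[0][0][0]);       // 1.0
        System.out.println(Arrays.toString(new int[] {batch.length, batch[0].length, batch[0][0].length}));

        float[][] fakeOutput = new float[1][10];  // made-up probabilities, for illustration only
        fakeOutput[0][7] = 0.9f;
        System.out.println(argmax(fakeOutput));   // 7
    }
}
```

Keeping the 255.0 scaling identical on both sides matters: the network only saw inputs in [0, 1] during training.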

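The model loads, but the runner doesn't work. When I run the program, I get "No Operation named [input] in the Graph", even though my input has that name. What am I doing wrong? I have the latest TensorFlow versions: 2.3.0 (Python) and 1.15.0 (Java).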

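I solved it. TensorFlow 2 seems to use an odd naming scheme, but it can be deciphered with the MetaGraphDef. First, you need the org.tensorflow.proto dependency. Then you can extract the information from the meta graph like this: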

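import java.util.Objects;

import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

// bundle is the loaded SavedModelBundle (called b in the question).
// metaGraphDef() returns the serialized MetaGraphDef as a byte[];
// parseFrom throws InvalidProtocolBufferException if the bytes are malformed.
final MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef());
// "serving_default" is the default signature key written by model.save()
final SignatureDef signatureDef = metaGraphDef.getSignatureDefMap().get("serving_default");

// The model has exactly one input and one output, so the first entry
// of each map is the one we want
final TensorInfo inputTensorInfo = signatureDef.getInputsMap()
    .values()
    .stream()
    .filter(Objects::nonNull)
    .findFirst()
    .orElseThrow(() -> ...);

final TensorInfo outputTensorInfo = signatureDef.getOutputsMap()
    .values()
    .stream()
    .filter(Objects::nonNull)
    .findFirst()
    .orElseThrow(() -> ...);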
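Taking the first map entry works here because the model has a single input and a single output; with several, the order of a protobuf map is not something to rely on. A minimal sketch of a stricter lookup follows, assuming the signature has first been copied into a plain `Map<String, String>` from signature key to tensor name. The helper `resolveTensorName` is hypothetical, and the tensor names in `main` are illustrative placeholders, not values read from a real model.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class SignatureLookup {
    /**
     * Hypothetical helper: returns the tensor name stored under key; if the key is
     * absent but the signature has exactly one entry, returns that entry. Otherwise
     * fails loudly instead of silently picking an arbitrary tensor.
     */
    static String resolveTensorName(Map<String, String> keyToTensorName, String key) {
        String byKey = keyToTensorName.get(key);
        if (byKey != null) {
            return byKey;
        }
        if (keyToTensorName.size() == 1) {
            return keyToTensorName.values().iterator().next();
        }
        throw new IllegalArgumentException(
            "no entry '" + key + "' among signature keys " + keyToTensorName.keySet());
    }

    public static void main(String[] args) {
        // Illustrative values only; print the real ones from your own SignatureDef
        Map<String, String> inputs = new LinkedHashMap<>();
        inputs.put("input", "serving_default_input:0");
        Map<String, String> outputs = new LinkedHashMap<>();
        outputs.put("output", "StatefulPartitionedCall:0");

        System.out.println(resolveTensorName(inputs, "input"));   // serving_default_input:0
        System.out.println(resolveTensorName(outputs, "output")); // StatefulPartitionedCall:0
    }
}
```

In the real program such a map could be filled with `signatureDef.getInputsMap().forEach((k, info) -> map.put(k, info.getName()))`.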

Now you can feed the tensor you created to the name returned by inputTensorInfo.getName() and fetch the result from outputTensorInfo.getName().
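The strings returned by `getName()` are tensor names of the form `operation_name:output_index`. For a TF 2 Keras SavedModel they typically look like `serving_default_input:0` and `StatefulPartitionedCall:0`, although the exact strings depend on the model, so print your own. In TensorFlow Java 1.x, `Session.Runner.feed(String, Tensor)` and `fetch(String)` accept either a bare operation name (meaning output 0) or this `name:index` form, so the string can be passed straight in. The hypothetical `TensorName` class below is only an illustration of that format, splitting at the last colon under a similar convention; it is not TensorFlow's own parser.

```java
public class TensorName {
    final String operation;
    final int index;

    TensorName(String operation, int index) {
        this.operation = operation;
        this.index = index;
    }

    /** Splits "op:index" at the last colon; a bare name or non-numeric suffix means output 0. */
    static TensorName parse(String name) {
        int colon = name.lastIndexOf(':');
        if (colon <= 0 || colon == name.length() - 1) {
            return new TensorName(name, 0);
        }
        try {
            int index = Integer.parseInt(name.substring(colon + 1));
            return new TensorName(name.substring(0, colon), index);
        } catch (NumberFormatException e) {
            return new TensorName(name, 0);
        }
    }

    @Override
    public String toString() {
        return operation + " (output " + index + ")";
    }

    public static void main(String[] args) {
        System.out.println(parse("serving_default_input:0"));   // serving_default_input (output 0)
        System.out.println(parse("StatefulPartitionedCall:0")); // StatefulPartitionedCall (output 0)
        System.out.println(parse("input"));                     // input (output 0)
    }
}
```

Read this way, the original error makes sense: the placeholder in the loaded graph carries the signature-derived operation name, not the Keras layer name `input`.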