检查目标时出错:预期 dense_Dense2 具有形状 [5],但得到形状为 [5,1] 的数组

Error when checking target: expected dense_Dense2 to have shape [,5], but got array with shape [5,1]

我在 Tensorflow website and I can't make the section which does model.fit 上做图像分类教程。

这是错误。

Error: Error when checking target: expected dense_Dense2 to have shape [,5], but got array with shape [5,1].

似乎对我通过的模型不满意。

我尝试传递 tf.tensor3D[] 结构,但它也不喜欢它:

Error: Error when checking model input: the Array of Tensors that you are passing to your model is not the size the model expected. Expected to see 1 Tensor(s), but instead got the following list of Tensor(s): Tensor

至于现在 [num_of_images, img_height, img_width, num_of_color_channels]

干杯!

const button = document.getElementById("random");
const randomPokemon = document.getElementById("randomPokemon");
const prediction = document.getElementById("prediction");
const imageWidth = 128;
const imageHeight = 128;

const createModel = (classes) => {
    const model = tf.sequential();

    model.add(tf.layers.conv2d({
        inputShape: [imageHeight, imageWidth, 3],
        kernelSize: [3, 3],
        filters: 16,
        padding: "same",
        activation: 'relu'
    }));
    model.add(tf.layers.maxPooling2d({}));
    model.add(tf.layers.conv2d({
        kernelSize: [3, 3],
        filters: 32,
        padding: "same",
        activation: 'relu'
    }));
    model.add(tf.layers.maxPooling2d({}));
    model.add(tf.layers.conv2d({
        kernelSize: [3, 3],
        filters: 64,
        padding: "same",
        activation: 'relu'
    }));
    model.add(tf.layers.maxPooling2d({}));
    model.add(tf.layers.flatten({}));
    model.add(tf.layers.dense({units: 128, activation: 'relu'}));
    model.add(tf.layers.dense({ units: classes }));
    model.compile({
        optimizer: "adam",
        loss: "categoricalCrossentropy",
        metrics: ['accuracy'],
    });

    return model;
}

const getTensorFromImage = (src) => new Promise((resolve, reject) => {
    const image = new Image();
    
    image.onload = () => {
        const canvas = document.createElement('canvas');
        const context = canvas.getContext('2d');
        canvas.width = imageWidth;
        canvas.height = imageHeight;
        image.width = image.naturalWidth;
        image.height = image.naturalHeight;

        context.drawImage(image, 0, 0, imageWidth, imageHeight);
        
        const tensor = tf.browser.fromPixels(canvas);

        resolve(tensor);
    };
    image.crossOrigin = "";
    image.referrerPolicy = "origin"
    image.src = src;
});

const pokemons = [
    "wartortle.png",
    "bulbasaur.png",
    "charmander.png",
    "blastoise.png",
    "kakuna.png",
  ];
  
const getRandomPokemon = () => {
  const index = Math.floor(Math.random() * pokemons.length);
  const pokemon = pokemons[index];

  return `https://cors-anywhere.herokuapp.com/https://img.pokemondb.net/sprites/bank/normal/${pokemon}`
}

const model = createModel(pokemons.length);

(async () => {

const tensor_promises = pokemons.map(pr => getTensorFromImage(`https://cors-anywhere.herokuapp.com/https://img.pokemondb.net/sprites/bank/normal/${pr}`))
  const tensors = await Promise.all(tensor_promises);
  labels_normalized = pokemons.map((pr, index) => index);
  const label_tensor = tf.tensor1d(labels_normalized);
  
  const computedData = tensors.map(pr => pr.arraySync());
  const tensor4d = tf.tensor4d(computedData, [computedData.length, ...tensors[0].shape])

  await model.fit(tensor4d, label_tensor, {
    epochs: 20,
  })
})();

const onClick = async () => {
  const pokemonSrc = getRandomPokemon();
  randomPokemon.src = pokemonSrc;
  const tensor = await getTensorFromImage(pokemonSrc);
  
  const model_prediction = model.predict([tensor]);
  const predictions = model_prediction[0];
  const pokemon = pokemonss[predictions[0]];
  
  prediction.textContent = `Prediction: ${pokemon}`;
}

button.addEventListener("click", onClick);
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/2.7.0/tf.min.js"></script>

<img id="randomPokemon"/>
<span id="prediction"></span>
<div>
  <button id="random">Random Pokemon</button>
</div>

用于预测的张量应该是4d张量。预期的形状可以看作是输入形状的数组。所以这里是一个 3d 张量数组,因此是一个 4d 张量。如果你只有一个图像(3d张量),你可以考虑扩大第一个轴(.expandDims(0))。这样,就好像你有一个图像数组。