如何知道 Tensorflow.js 中的模型预测是什么?
How to know what are the model predictions in Tensorflow.js?
我正在使用通过 MobileNetV2 的迁移学习构建的自定义模型来检测人类意识。
我训练了模型,它在测试数据上达到了 99% 的分类准确率。
有3个类即0、5、10代表认知水平
我的模型returns在 React Native 中对来自 TensorCamera 的输入图像进行了以下预测,但我对预测在这个 Tensor 对象中的实际位置感到困惑
const prediction = model.predict(tensor.reshape([1,224,224,3]));
预测输出:
Tensor {
"dataId": Object {},
"dtype": "float32",
"id": 160213,
"isDisposedInternal": false,
"kept": false,
"rankType": "2",
"scopeId": 365032,
"shape": Array [
1,
3,
],
"size": 3,
"strides": Array [
3,
],
}
处理相机输入并每 3 帧进行预测的函数
const handleCameraStream = imageAsTensors => {
const verbose = true;
console.log("Tensor input 1");
try {
tf.print(imageAsTensors, verbose);
} catch (e) {
console.log("Tensor 1 not found!");
}
const loop = async () => {
if (loadedModel !== null) {
if (frameCount % makePredictionsEveryNFrames === 0) {
const imageTensor = imageAsTensors.next().value;
console.log("Tensor input 2");
tf.print(imageTensor, verbose);
await getPrediction(imageTensor).catch(e => console.log(e));
}
}
frameCount += 1;
frameCount = frameCount % makePredictionsEveryNFrames;
requestAnimationFrameId = requestAnimationFrame(loop);
};
//loop infinitely to constantly make predictions
loop();
};
获取预测的函数
const getPrediction = async tensor => {
if (!tensor) {
console.log("Tensor not found!");
return;
}
const model = await loadedModel;
const prediction = model.predict(tensor.reshape([1, 224, 224, 3]));
if (!prediction || prediction.length === 0) {
console.log("No prediction available");
return;
}
console.log(prediction);
//console.log(`Predictions: ${JSON.stringify(prediction)}`);
// Only take the predictions with a probability of 30% and greater
//Stop looping
cancelAnimationFrame(requestAnimationFrameId);
//setPredictionFound(true);
//setModelPrediction(prediction[0].className);
tensor.dispose();
};
我需要使用 dataSync() 来获取预测结果
const preds = prediction.dataSync();
preds.forEach((pred, i) => {
//console.log(`x: ${i}, pred: ${pred}`);
if (pred > 0.8) {
console.log(`x: ${i}, pred: ${pred}`);
setModelPrediction(prediction: pred, class: i)
}
x: 0, pred: 0.9000627994537354
x: 1, pred: 0.023466499522328377
x: 2, pred: 0.0764707624912262
我正在使用通过 MobileNetV2 的迁移学习构建的自定义模型来检测人类意识。
我训练了模型,它在测试数据上达到了 99% 的分类准确率。
有3个类即0、5、10代表认知水平
我的模型returns在 React Native 中对来自 TensorCamera 的输入图像进行了以下预测,但我对预测在这个 Tensor 对象中的实际位置感到困惑
const prediction = model.predict(tensor.reshape([1,224,224,3]));
预测输出:
Tensor {
"dataId": Object {},
"dtype": "float32",
"id": 160213,
"isDisposedInternal": false,
"kept": false,
"rankType": "2",
"scopeId": 365032,
"shape": Array [
1,
3,
],
"size": 3,
"strides": Array [
3,
],
}
处理相机输入并每 3 帧进行预测的函数
const handleCameraStream = imageAsTensors => {
const verbose = true;
console.log("Tensor input 1");
try {
tf.print(imageAsTensors, verbose);
} catch (e) {
console.log("Tensor 1 not found!");
}
const loop = async () => {
if (loadedModel !== null) {
if (frameCount % makePredictionsEveryNFrames === 0) {
const imageTensor = imageAsTensors.next().value;
console.log("Tensor input 2");
tf.print(imageTensor, verbose);
await getPrediction(imageTensor).catch(e => console.log(e));
}
}
frameCount += 1;
frameCount = frameCount % makePredictionsEveryNFrames;
requestAnimationFrameId = requestAnimationFrame(loop);
};
//loop infinitely to constantly make predictions
loop();
};
获取预测的函数
const getPrediction = async tensor => {
if (!tensor) {
console.log("Tensor not found!");
return;
}
const model = await loadedModel;
const prediction = model.predict(tensor.reshape([1, 224, 224, 3]));
if (!prediction || prediction.length === 0) {
console.log("No prediction available");
return;
}
console.log(prediction);
//console.log(`Predictions: ${JSON.stringify(prediction)}`);
// Only take the predictions with a probability of 30% and greater
//Stop looping
cancelAnimationFrame(requestAnimationFrameId);
//setPredictionFound(true);
//setModelPrediction(prediction[0].className);
tensor.dispose();
};
我需要使用 dataSync() 来获取预测结果
const preds = prediction.dataSync();
preds.forEach((pred, i) => {
//console.log(`x: ${i}, pred: ${pred}`);
if (pred > 0.8) {
console.log(`x: ${i}, pred: ${pred}`);
setModelPrediction(prediction: pred, class: i)
}
x: 0, pred: 0.9000627994537354
x: 1, pred: 0.023466499522328377
x: 2, pred: 0.0764707624912262