为 multi class classification 构建 tflite 模型
Building a tflite model for multi class classification
我阅读了多个代码实验室,其中 Google class 处理属于一个 class 的图像。如果我需要使用 2 个或更多 class 怎么办?例如,如果我想 class 确定图像是否包含水果或蔬菜,然后 class 确定它是哪种类型的水果或蔬菜。
您可以使用 TensorFlow(特别是使用 Keras)轻松训练卷积神经网络 (CNN)。互联网上有大量示例。参见 here and here。
接下来,我们使用 tf.lite.TFLiteConverter
、
将 Keras 保存的模型(.h5
文件)转换为 .tflite
文件
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
参见 here。
现在,在 Android 中,我们拍摄 Bitmap
图像并将其转换为 float[][][][]
、
private float[][][][] convertImageToFloatArray ( Bitmap image ) {
float[][][][] imageArray = new float[1][modelInputDim][modelInputDim][1] ;
for ( int x = 0 ; x < modelInputDim ; x ++ ) {
for ( int y = 0 ; y < modelInputDim ; y ++ ) {
float R = ( float )Color.red( image.getPixel( x , y ) );
float G = ( float )Color.green( image.getPixel( x , y ) );
float B = ( float )Color.blue( image.getPixel( x , y ) );
double grayscalePixel = (( 0.3 * R ) + ( 0.59 * G ) + ( 0.11 * B )) / 255;
imageArray[0][x][y][0] = (float)grayscalePixel ;
}
}
return imageArray ;
}
其中 modelInputDim
是模型的图像输入大小。上面的代码片段将 RGB 图像转换为灰度图像。
现在,我们进行最后的推理,
private int modelInputDim = 28 ;
private int outputDim = 3 ;
private float[] performInference(Bitmap frame , RectF cropImageRectF ) {
Bitmap croppedBitmap = getCroppedBitmap( frame , cropImageRectF ) ;
Bitmap croppedFrame = resizeBitmap( croppedBitmap );
float[][][][] imageArray = convertImageToFloatArray( croppedFrame ) ;
float[][] outputArray = new float[1][outputDim] ;
interpreter.run( imageArray , outputArray ) ;
return outputArray[0] ;
}
我准备了 Android 个应用程序集,这些应用程序使用 Android 中的 TFLite 模型。参见 here。
我阅读了多个代码实验室,其中 Google class 处理属于一个 class 的图像。如果我需要使用 2 个或更多 class 怎么办?例如,如果我想 class 确定图像是否包含水果或蔬菜,然后 class 确定它是哪种类型的水果或蔬菜。
您可以使用 TensorFlow(特别是使用 Keras)轻松训练卷积神经网络 (CNN)。互联网上有大量示例。参见 here and here。
接下来,我们使用 tf.lite.TFLiteConverter
、
.h5
文件)转换为 .tflite
文件
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
参见 here。
现在,在 Android 中,我们拍摄 Bitmap
图像并将其转换为 float[][][][]
、
private float[][][][] convertImageToFloatArray ( Bitmap image ) {
float[][][][] imageArray = new float[1][modelInputDim][modelInputDim][1] ;
for ( int x = 0 ; x < modelInputDim ; x ++ ) {
for ( int y = 0 ; y < modelInputDim ; y ++ ) {
float R = ( float )Color.red( image.getPixel( x , y ) );
float G = ( float )Color.green( image.getPixel( x , y ) );
float B = ( float )Color.blue( image.getPixel( x , y ) );
double grayscalePixel = (( 0.3 * R ) + ( 0.59 * G ) + ( 0.11 * B )) / 255;
imageArray[0][x][y][0] = (float)grayscalePixel ;
}
}
return imageArray ;
}
其中 modelInputDim
是模型的图像输入大小。上面的代码片段将 RGB 图像转换为灰度图像。
现在,我们进行最后的推理,
private int modelInputDim = 28 ;
private int outputDim = 3 ;
private float[] performInference(Bitmap frame , RectF cropImageRectF ) {
Bitmap croppedBitmap = getCroppedBitmap( frame , cropImageRectF ) ;
Bitmap croppedFrame = resizeBitmap( croppedBitmap );
float[][][][] imageArray = convertImageToFloatArray( croppedFrame ) ;
float[][] outputArray = new float[1][outputDim] ;
interpreter.run( imageArray , outputArray ) ;
return outputArray[0] ;
}
我准备了 Android 个应用程序集,这些应用程序使用 Android 中的 TFLite 模型。参见 here。