Firebase ML Kit with a custom TFLite model produces the same inference for varied inputs on Android
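I'm working on an audio classification model that classifies audio clips by genre.
The model takes audio features such as the spectral centroid and outputs a genre such as classical or rock. Input shape -> [1,26]
This is a multi-class classifier.
I have a Keras model that I converted to a TFLite model for use on mobile. I tested the original model and its accuracy is quite good; the TFLite model performs just as well when run with Python on my PC.
When I deploy it to Firebase ML Kit and use it through the Android API, it produces a single label/class as the output for all of the varied inputs. I don't think the model is the problem, because it works fine in my Jupyter notebook.
I don't understand how the same model can produce different inferences for the same input.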
Keras model:
# The test model
import numpy as np
from tensorflow.keras import layers, models

model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])
history = model.fit(X_train,
                    y_train,
                    epochs=10)
#print(X_test[:1],y_test)
# predict_classes() was removed in newer TensorFlow releases;
# argmax over predict() gives the same class indices for a softmax output.
pred = np.argmax(model.predict(X_test), axis=-1)
print(pred)
print(y_test)
Conversion code:
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
# Write the converted model to disk and close the file handle afterwards
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
Input/Output shapes:
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Print input shape and type
print(interpreter.get_input_details()[0]['shape']) # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype']) # Example: <class 'numpy.float32'>
# Print output shape and type
print(interpreter.get_output_details()[0]['shape']) # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype']) # Example: <class 'numpy.float32'>
[ 1 26]
<class 'numpy.float32'>
[ 1 10]
<class 'numpy.float32'>
Demo Kotlin code used for testing:
listenButton.setOnClickListener {
    incorrecttagButton.alpha = 1f
    incorrecttagButton.isClickable = true
    // Code for listening to music
    FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener { isDownloaded ->
            // Use the remote model if it has been downloaded, otherwise the local one
            val options =
                if (isDownloaded) {
                    FirebaseModelInterpreterOptions.Builder(remoteModel).build()
                } else {
                    FirebaseModelInterpreterOptions.Builder(localModel).build()
                }
            Log.d("HUSKY", "Downloaded? ${isDownloaded}")
            val interpreter = FirebaseModelInterpreter.getInstance(options)
            val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 26))
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 10))
                .build()
            if (songNum == 5) {
                songNum = 0
            }
            val testSong = testsongs[songNum]
            Log.d("HUSKY", "Song num = ${songNum} F = ${testSong} ")
            // Parse the comma-separated feature string into a [1, 26] float array
            val input = Array(1) { FloatArray(26) }
            val itr = testSong.split(",").toTypedArray()
            val preInput = itr.map { it.toFloat() }
            preInput.forEachIndexed { index, value ->
                input[0][index] = value
            }
            //val input = preInput.toTypedArray()
            Log.d("HUSKY", "${input[0][1]}")
            val inputs = FirebaseModelInputs.Builder()
                .add(input) // add() as many input arrays as your model requires
                .build()
            val labelArray = "blues classical country disco hiphop jazz metal pop reggae rock".split(" ").toTypedArray()
            Log.d("HUSKY2", "GG")
            interpreter?.run(inputs, inputOutputOptions)?.addOnSuccessListener { result ->
                Log.d("HUSKY2", "GGWP")
                val output = result.getOutput<Array<FloatArray>>(0)
                val probabilities = output[0]
                // Find the class with the highest probability
                var bestMatch = 0f
                var bestMatchIndex = 0
                for (i in probabilities.indices) {
                    if (probabilities[i] > bestMatch) {
                        bestMatch = probabilities[i]
                        bestMatchIndex = i
                    }
                    Log.d("HUSKY2", "${labelArray[i]} ${probabilities[i]}")
                }
                genreLabel.text = labelArray[bestMatchIndex].capitalize()
                confidenceLabel.text = probabilities[bestMatchIndex].toString()
                // ...
            }?.addOnFailureListener { e ->
                // Task failed with an exception
                // ...
                Log.d("HUSKY2", "GGWP :( ${e.toString()}")
            }
        }
}
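I change songs by incrementing songNum, which indexes into an array of strings. The features are stored as comma-separated strings.
Regardless of the input features (songNum selects one of the songs [0-4]), the output is always the same, as shown below, and the confidence for pop is always 1.0: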
2020-02-25 00:11:21.014 17434-17434/com.rohanbojja.audient D/HUSKY: Downloaded? true
2020-02-25 00:11:21.015 17434-17434/com.rohanbojja.audient D/HUSKY: Song num = 0 F = 0.3595172803692916,0.04380025714635849,1365.710742222286,1643.935571084307,2725.445556640625,0.06513807508680555,-273.0061247040518,132.66331747988934,-31.86709317807114,44.21442952318603,4.335704872427025,32.32360339344842,-2.4662076330637714,20.458242724823684,-4.760171779927926,20.413702740993585,3.69545905318442,8.581128171784677,-15.601809275025104,5.295758930950924,-5.270195074271744,5.895109210872318,-6.1406603018722645,-2.9278519508415286,-1.9189588023091468,5.954495267889836
2020-02-25 00:11:21.016 17434-17434/com.rohanbojja.audient D/HUSKY: 0.043800257
2020-02-25 00:11:21.016 17434-17434/com.rohanbojja.audient D/HUSKY2: GG
2020-02-25 00:11:21.021 17434-17434/com.rohanbojja.audient D/HUSKY2: GGWP
2020-02-25 00:11:21.021 17434-17434/com.rohanbojja.audient D/HUSKY2: blues 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: classical 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: country 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: disco 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: hiphop 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: jazz 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: metal 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: pop 1.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: reggae 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: rock 0.0
The output in the Jupyter notebook looks like this:
(blues,) (classical,) (country,) (disco,) (hiphop,) (jazz,) (metal,) (pop,) (reggae,) (rock,)
0 0.257037 0.000705 0.429687 0.030933 0.009291 0.004909 1.734001e-03 0.000912 0.203305 0.061488
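My conclusion is that I'm misusing the ML Kit API, or that something is wrong with how I pass the input data or read the output data. I'm new to Android development.
Output:
'pop' always has a confidence of 1.0!
Expected output:
Each genre should have some confidence in the range [0, 1.0], rather than 'pop' every time, like the result I get from the Jupyter notebook.
Sorry for the messy code.
Any help would be appreciated!
Update 1: I swapped the relu activations for sigmoid and I can see a difference. It is still almost always "pop", but with a confidence of about 0.30. Super mysterious now. This only happens with ML Kit, by the way; I haven't really tried implementing it locally.
Update 2: I don't understand how the same model can give different inferences. I'm lost.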
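The problem turned out to be missing normalization: I wasn't normalizing the features after extracting them at prediction time, i.e. the extracted features were passed to the model untransformed.
I had transformed the training data with: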
X = StandardScaler().fit_transform(np.array(data.iloc[:,1:-1]))
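As a rough illustration of why raw features can pin one class at 1.0: the network was trained on standardized inputs, while the raw vector in the log contains values such as 2725.4 and -273.0. Inputs that large are likely to push the logits far apart, and softmax then saturates. The sketch below uses made-up logits (not values taken from the actual model) and a hypothetical scaling factor to show the effect in float32.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array."""
    shifted = logits - np.max(logits)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

# Hypothetical logits, chosen only for illustration.
moderate_logits = np.array([1.2, -0.5, 1.7, 0.1], dtype=np.float32)
# The same logits blown up by a large factor, mimicking unnormalized inputs.
extreme_logits = moderate_logits * np.float32(500.0)

moderate_probs = softmax(moderate_logits)
extreme_probs = softmax(extreme_logits)

print("moderate:", moderate_probs)  # probability mass spread over several classes
print("extreme: ", extreme_probs)   # one class takes all the mass, the rest underflow to 0
```

This would also be consistent with Update 1: a sigmoid hidden activation bounds the intermediate values, so the final softmax is less extreme even though the prediction is still wrong.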
To fix this, I had to apply the same transformation to the features at prediction time:
scaler=StandardScaler().fit(np.array(data.iloc[:,1:-1]))
input_data = scaler.transform(input_data2)
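On Android the features never pass through Python, so the same transform has to be reproduced in the app before calling the interpreter. One possible approach, sketched below, is to export the per-feature mean and standard deviation computed at training time and apply `(x - mean) / scale` on the device. For a default `StandardScaler`, `transform` is equivalent to that formula using its `mean_` and `scale_` attributes. The sketch uses plain NumPy with random stand-in data; the function names and the JSON layout are assumptions made for illustration.

```python
import json
import numpy as np

def fit_standardization(features):
    """Per-feature mean and population std (ddof=0), as StandardScaler uses by default."""
    mean = features.mean(axis=0)
    scale = features.std(axis=0)
    scale[scale == 0.0] = 1.0  # constant features: avoid division by zero
    return mean, scale

def standardize(x, mean, scale):
    """Apply the same (x - mean) / scale transform used on the training data."""
    return (x - mean) / scale

# Random stand-in for the 26-column training feature matrix (not real data).
rng = np.random.default_rng(0)
train_features = rng.normal(loc=50.0, scale=20.0, size=(100, 26))

mean, scale = fit_standardization(train_features)

# Serialize the parameters so they can be bundled with the app (hypothetical layout).
params_json = json.dumps({"mean": mean.tolist(), "scale": scale.tolist()})

# What the app would do: load the parameters and standardize each raw feature vector.
loaded = json.loads(params_json)
raw_input = train_features[:1]  # one raw sample with shape (1, 26)
model_input = standardize(
    raw_input, np.array(loaded["mean"]), np.array(loaded["scale"])
).astype(np.float32)

print(model_input.shape, model_input.dtype)
```

In the Kotlin code, the equivalent step would be subtracting the stored mean and dividing by the stored scale for each of the 26 values before filling the input array.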