Changes required for using non-quantized tflite files in MainActivity.java
This MainActivity.java was written for a quantized model, and I am trying to use a non-quantized (float) model. After making the changes described here and here to MainActivity.java, my code is:
public class MainActivity extends AppCompatActivity implements AdapterView.OnItemSelectedListener {
private static final String TAG = "MainActivity";
private Button mRun;
private ImageView mImageView;
private Bitmap mSelectedImage;
private GraphicOverlay mGraphicOverlay;
// Max width (portrait mode)
private Integer mImageMaxWidth;
// Max height (portrait mode)
private Integer mImageMaxHeight;
private final String[] mFilePaths =
new String[]{"mountain.jpg", "tennis.jpg","96580.jpg"};
/**
* Name of the model file hosted with Firebase.
*/
private static final String HOSTED_MODEL_NAME = "mobilenet_v1_224_quant";
private static final String LOCAL_MODEL_ASSET = "retrained_graph_mobilenet_1_224.tflite";
/**
* Name of the label file stored in Assets.
*/
private static final String LABEL_PATH = "labels.txt";
/**
* Number of results to show in the UI.
*/
private static final int RESULTS_TO_SHOW = 3;
/**
* Dimensions of inputs.
*/
private static final int DIM_BATCH_SIZE = 1;
private static final int DIM_PIXEL_SIZE = 3;
private static final int DIM_IMG_SIZE_X = 224;
private static final int DIM_IMG_SIZE_Y = 224;
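// Normalization constants for float model input: each channel is fed as
// (value - IMAGE_MEAN) / IMAGE_STD, mapping 0..255 to roughly [-1, 1].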
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128.0f;
/**
* Labels corresponding to the output of the vision model.
*/
private List<String> mLabelList;
private final PriorityQueue<Map.Entry<String, Float>> sortedLabels =
new PriorityQueue<>(
RESULTS_TO_SHOW,
new Comparator<Map.Entry<String, Float>>() {
@Override
public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float>
o2) {
return (o1.getValue()).compareTo(o2.getValue());
}
});
/* Preallocated buffers for storing image data. */
private final int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
/**
* An instance of the driver class to run model inference with Firebase.
*/
private FirebaseModelInterpreter mInterpreter;
/**
* Data configuration of input & output data of model.
*/
private FirebaseModelInputOutputOptions mDataOptions;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mGraphicOverlay = findViewById(R.id.graphic_overlay);
mImageView = findViewById(R.id.image_view);
Spinner dropdown = findViewById(R.id.spinner);
List<String> items = new ArrayList<>();
for (int i = 0; i < mFilePaths.length; i++) {
items.add("Image " + (i + 1));
}
ArrayAdapter<String> adapter = new ArrayAdapter<>(this, android.R.layout
.simple_spinner_dropdown_item, items);
dropdown.setAdapter(adapter);
dropdown.setOnItemSelectedListener(this);
mLabelList = loadLabelList(this);
mRun = findViewById(R.id.button_run);
mRun.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
runModelInference();
}
});
int[] inputDims = {DIM_BATCH_SIZE, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y, DIM_PIXEL_SIZE};
int[] outputDims = {DIM_BATCH_SIZE, mLabelList.size()};
try {
mDataOptions =
new FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, inputDims)
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, outputDims)
.build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions
.Builder()
.requireWifi()
.build();
FirebaseLocalModelSource localModelSource =
new FirebaseLocalModelSource.Builder("asset")
.setAssetFilePath(LOCAL_MODEL_ASSET).build();
FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder
(HOSTED_MODEL_NAME)
.enableModelUpdates(true)
.setInitialDownloadConditions(conditions)
.setUpdatesDownloadConditions(conditions) // You could also specify
// different conditions
// for updates
.build();
FirebaseModelManager manager = FirebaseModelManager.getInstance();
manager.registerLocalModelSource(localModelSource);
manager.registerCloudModelSource(cloudSource);
FirebaseModelOptions modelOptions =
new FirebaseModelOptions.Builder()
.setCloudModelName(HOSTED_MODEL_NAME)
.setLocalModelName("asset")
.build();
mInterpreter = FirebaseModelInterpreter.getInstance(modelOptions);
} catch (FirebaseMLException e) {
showToast("Error while setting up the model");
e.printStackTrace();
}
}
private void runModelInference() {
if (mInterpreter == null) {
Log.e(TAG, "Image classifier has not been initialized; Skipped.");
return;
}
// Create input data.
ByteBuffer imgData = convertBitmapToByteBuffer(mSelectedImage, mSelectedImage.getWidth(),
mSelectedImage.getHeight());
try {
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder().add(imgData).build();
// Here's where the magic happens!!
mInterpreter
.run(inputs, mDataOptions)
.addOnFailureListener(new OnFailureListener() {
@Override
public void onFailure(@NonNull Exception e) {
e.printStackTrace();
showToast("Error running model inference");
}
})
.continueWith(
new Continuation<FirebaseModelOutputs, List<String>>() {
@Override
public List<String> then(Task<FirebaseModelOutputs> task) {
float[][] labelProbArray = task.getResult()
.<float[][]>getOutput(0);
List<String> topLabels = getTopLabels(labelProbArray);
mGraphicOverlay.clear();
GraphicOverlay.Graphic labelGraphic = new LabelGraphic
(mGraphicOverlay, topLabels);
mGraphicOverlay.add(labelGraphic);
return topLabels;
}
});
} catch (FirebaseMLException e) {
e.printStackTrace();
showToast("Error running model inference");
}
}
/**
* Gets the top labels in the results.
*/
private synchronized List<String> getTopLabels(float[][] labelProbArray) {
for (int i = 0; i < mLabelList.size(); ++i) {
sortedLabels.add(
new AbstractMap.SimpleEntry<>(mLabelList.get(i), (labelProbArray[0][i] )));
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
}
List<String> result = new ArrayList<>();
final int size = sortedLabels.size();
for (int i = 0; i < size; ++i) {
Map.Entry<String, Float> label = sortedLabels.poll();
result.add(label.getKey() + ":" + label.getValue());
}
Log.d(TAG, "labels: " + result.toString());
return result;
}
/**
* Reads label list from Assets.
*/
private List<String> loadLabelList(Activity activity) {
List<String> labelList = new ArrayList<>();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(activity.getAssets().open
(LABEL_PATH)))) {
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
}
} catch (IOException e) {
Log.e(TAG, "Failed to read label list.", e);
}
return labelList;
}
/**
* Writes Image data into a {@code ByteBuffer}.
*/
private synchronized ByteBuffer convertBitmapToByteBuffer(
Bitmap bitmap, int width, int height) {
ByteBuffer imgData =
ByteBuffer.allocateDirect(
4 * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
imgData.order(ByteOrder.nativeOrder());
Bitmap scaledBitmap = Bitmap.createScaledBitmap(bitmap, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y,
true);
imgData.rewind();
scaledBitmap.getPixels(intValues, 0, scaledBitmap.getWidth(), 0, 0,
scaledBitmap.getWidth(), scaledBitmap.getHeight());
// Convert the image to floating point.
int pixel = 0;
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
final int val = intValues[pixel++];
imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat(((val & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
}
}
return imgData;
}
private void showToast(String message) {
Toast.makeText(getApplicationContext(), message, Toast.LENGTH_SHORT).show();
}
public void onItemSelected(AdapterView<?> parent, View v, int position, long id) {
mGraphicOverlay.clear();
mSelectedImage = getBitmapFromAsset(this, mFilePaths[position]);
if (mSelectedImage != null) {
// Get the dimensions of the View
Pair<Integer, Integer> targetedSize = getTargetedWidthHeight();
int targetWidth = targetedSize.first;
int maxHeight = targetedSize.second;
// Determine how much to scale down the image
float scaleFactor =
Math.max(
(float) mSelectedImage.getWidth() / (float) targetWidth,
(float) mSelectedImage.getHeight() / (float) maxHeight);
Bitmap resizedBitmap =
Bitmap.createScaledBitmap(
mSelectedImage,
(int) (mSelectedImage.getWidth() / scaleFactor),
(int) (mSelectedImage.getHeight() / scaleFactor),
true);
mImageView.setImageBitmap(resizedBitmap);
mSelectedImage = resizedBitmap;
}
}
@Override
public void onNothingSelected(AdapterView<?> parent) {
// Do nothing
}
// Utility functions for loading and resizing images from app asset folder.
public static Bitmap getBitmapFromAsset(Context context, String filePath) {
AssetManager assetManager = context.getAssets();
InputStream is;
Bitmap bitmap = null;
try {
is = assetManager.open(filePath);
bitmap = BitmapFactory.decodeStream(is);
} catch (IOException e) {
e.printStackTrace();
}
return bitmap;
}
// Returns max image width, always for portrait mode. Caller needs to swap width / height for
// landscape mode.
private Integer getImageMaxWidth() {
if (mImageMaxWidth == null) {
// Calculate the max width in portrait mode. This is done lazily since we need to
// wait for a UI layout pass to get the right values. So delay it to first time image
// rendering time.
mImageMaxWidth = mImageView.getWidth();
}
return mImageMaxWidth;
}
// Returns max image height, always for portrait mode. Caller needs to swap width / height for
// landscape mode.
private Integer getImageMaxHeight() {
if (mImageMaxHeight == null) {
// Calculate the max width in portrait mode. This is done lazily since we need to
// wait for a UI layout pass to get the right values. So delay it to first time image
// rendering time.
mImageMaxHeight =
mImageView.getHeight();
}
return mImageMaxHeight;
}
// Gets the targeted width / height.
private Pair<Integer, Integer> getTargetedWidthHeight() {
int targetWidth;
int targetHeight;
int maxWidthForPortraitMode = getImageMaxWidth();
int maxHeightForPortraitMode = getImageMaxHeight();
targetWidth = maxWidthForPortraitMode;
targetHeight = maxHeightForPortraitMode;
return new Pair<>(targetWidth, targetHeight);
}
}
But I still get "Failed to get input dimensions. 0-th input should have 268203 bytes, but found 1072812 bytes" for Inception and "0-th input should have 150528 bytes, but found 602112 bytes" for MobileNet. So a factor of 4 is always present.
To see the changes I made, the output of diff original.java changed.java is (ignore the line numbers):
32a33,34
> private static final int IMAGE_MEAN = 128;
> private static final float IMAGE_STD = 128.0f;
150,151c152,153
< byte[][] labelProbArray = task.getResult()
< .<byte[][]>getOutput(0);
---
> float[][] labelProbArray = task.getResult()
> .<float[][]>getOutput(0);
170c172
< private synchronized List<String> getTopLabels(byte[][] labelProbArray) {
---
> private synchronized List<String> getTopLabels(float[][] labelProbArray) {
173,174c175
< new AbstractMap.SimpleEntry<>(mLabelList.get(i), (labelProbArray[0][i] &
< 0xff) / 255.0f));
---
> new AbstractMap.SimpleEntry<>(mLabelList.get(i), (labelProbArray[0][i] )));
214c215,216
< DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
---
> 4*DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
>
226,228c228,232
< imgData.put((byte) ((val >> 16) & 0xFF));
< imgData.put((byte) ((val >> 8) & 0xFF));
< imgData.put((byte) (val & 0xFF));
---
> imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
> imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
> imgData.putFloat(((val & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
Here is how the buffer is allocated in the codelab:
ByteBuffer imgData = ByteBuffer.allocateDirect(
DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
DIM_BATCH_SIZE - a typical usage is to support batching, if the model supports it. In our sample and in your test you feed in one image at a time, so it stays 1.
DIM_PIXEL_SIZE - we set 3 in the codelab, corresponding to one byte each for r/g/b.
However, it looks like you are using a float model. Then, instead of one byte per r/g/b channel, you use one float (4 bytes) to represent each channel (you already figured this part out yourself). The buffer allocated with the code above is therefore no longer sufficient.
To match exactly how imgData is populated, the allocation formula should be:
ByteBuffer imgData = ByteBuffer.allocateDirect(
DIM_BATCH_SIZE * getImageSizeX() * getImageSizeY() * DIM_PIXEL_SIZE
* getNumBytesPerChannel());
getNumBytesPerChannel() should be 4 in your case.
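To make the formula concrete, here is a minimal sketch of getNumBytesPerChannel(). In the official TFLite sample app it is defined per classifier subclass; the version below is a simplified stand-in:

// Bytes used to store one color channel of one pixel in the input buffer.
// Quantized models use a single byte per channel; float models use a
// 4-byte float per channel.
private int getNumBytesPerChannel() {
    boolean isFloatModel = true; // true for a retrained float MobileNet
    return isFloatModel ? 4 : 1;
}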
[Update for the new question, regarding the following error]:
Failed to get input dimensions. 0-th input should have 268203 bytes, but found 1072812 bytes
This checks that the number of bytes the model expects == the number of bytes passed in. 268203 = 299 * 299 * 3 and 1072812 = 4 * 299 * 299 * 3. It looks like you are using a quantized model but feeding it data prepared for a float model. Could you double-check which model you are using? To keep it simple, don't specify the cloud model source, and use only the local model from assets.
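Just to spell out the arithmetic behind that check, in plain Java, with the numbers taken from the two error messages above (299 for Inception, 224 for MobileNet):

// Expected input sizes: quantized = 1 byte/channel, float = 4 bytes/channel.
int inceptionQuantBytes = 299 * 299 * 3;     // 268203, what the model expects
int inceptionFloatBytes = 4 * 299 * 299 * 3; // 1072812, what was passed in
int mobilenetQuantBytes = 224 * 224 * 3;     // 150528
int mobilenetFloatBytes = 4 * 224 * 224 * 3; // 602112
// The constant factor of 4 between "expected" and "found" means the loaded
// model is quantized while the input buffer was built for a float model.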
[Update 0628, developer said they trained a float model]:
It could be that your model is wrong; it could also be that a downloaded cloud model is overriding your local model. But the error message tells us the model being loaded is not a float model.
To isolate the issue, I'd suggest the following tests (a sketch of the local-only setup for test 1 follows the list):
1) Remove setCloudModelName / registerCloudModelSource from the quickstart app
2) Play with the official TFLite float model: you will have to download the model mentioned in the comments and change Camera2BasicFragment to use ImageClassifierFloatInception (instead of ImageClassifierQuantizedMobileNet)
3) Still with the same TFLite sample app, switch to your own trained model, making sure to adjust the image size to your values.
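For test 1, a minimal sketch of a local-model-only setup, reusing the "asset" name and LOCAL_MODEL_ASSET from the code above (the cloud source registration and setCloudModelName are simply dropped):

try {
    FirebaseLocalModelSource localModelSource =
            new FirebaseLocalModelSource.Builder("asset")
                    .setAssetFilePath(LOCAL_MODEL_ASSET)
                    .build();
    FirebaseModelManager.getInstance().registerLocalModelSource(localModelSource);
    // No cloud source is registered, so no downloaded model can override
    // the local .tflite file while you debug.
    FirebaseModelOptions modelOptions =
            new FirebaseModelOptions.Builder()
                    .setLocalModelName("asset")
                    .build();
    mInterpreter = FirebaseModelInterpreter.getInstance(modelOptions);
} catch (FirebaseMLException e) {
    e.printStackTrace();
}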