如何使用 C++ api 在 tflite 中获取权重?
How to get weights in tflite using c++ api?
我在设备上使用 .tflite 模型。最后一层是 ConditionalRandomField 层,我需要该层的权重来进行预测。
我如何使用 C++ api 获得权重?
相关:How can I view weights in a .tflite file?
Netron 或 flatc 不能满足我的需求。对设备来说太重了。
TfLiteNode 似乎将权重存储在 void* user_data 或 void* builtin_data 中。我如何阅读它们?
更新:
结论:.tflite 不存储 CRF 权重,而 .h5 剂量。 (可能是因为它们不影响输出。)
我做什么:
// obtain from model.
Interpreter *interpreter;
// get the last index of nodes.
// I'm not sure if the index sequence of nodes is the direction which tensors or layers flows.
const TfLiteNode *node = &((interpreter->node_and_registration(interpreter->nodes_size()-1))->first);
// then follow the answer of @yyoon
在TFLite节点中,权重应该存储在inputs
数组中,其中包含对应TfLiteTensor*
.
的索引
所以,如果你已经得到了最后一层的TfLiteNode*
,你可以这样做来读取权重值。
TfLiteContext* context; // You would usually have access to this already.
TfLiteNode* node; // <obtain this from the graph>;
for (int i = 0; i < node->inputs->size; ++i) {
TfLiteTensor* input_tensor = GetInput(context, node, i);
// Determine if this is a weight tensor.
// Usually the weights will be memory-mapped read-only tensor
// directly baked in the TFLite model (flatbuffer).
if (input_tensor->allocation_type == kTfLiteMmapRo) {
// Read the values from input_tensor, based on its type.
// For example, if you have float weights,
const float* weights = GetTensorData<float>(input_tensor);
// <read the weight values...>
}
}
我在设备上使用 .tflite 模型。最后一层是 ConditionalRandomField 层,我需要该层的权重来进行预测。 我如何使用 C++ api 获得权重?
相关:How can I view weights in a .tflite file?
Netron 或 flatc 不能满足我的需求。对设备来说太重了。
TfLiteNode 似乎将权重存储在 void* user_data 或 void* builtin_data 中。我如何阅读它们?
更新:
结论:.tflite 不存储 CRF 权重,而 .h5 剂量。 (可能是因为它们不影响输出。)
我做什么:
// obtain from model.
Interpreter *interpreter;
// get the last index of nodes.
// I'm not sure if the index sequence of nodes is the direction which tensors or layers flows.
const TfLiteNode *node = &((interpreter->node_and_registration(interpreter->nodes_size()-1))->first);
// then follow the answer of @yyoon
在TFLite节点中,权重应该存储在inputs
数组中,其中包含对应TfLiteTensor*
.
所以,如果你已经得到了最后一层的TfLiteNode*
,你可以这样做来读取权重值。
TfLiteContext* context; // You would usually have access to this already.
TfLiteNode* node; // <obtain this from the graph>;
for (int i = 0; i < node->inputs->size; ++i) {
TfLiteTensor* input_tensor = GetInput(context, node, i);
// Determine if this is a weight tensor.
// Usually the weights will be memory-mapped read-only tensor
// directly baked in the TFLite model (flatbuffer).
if (input_tensor->allocation_type == kTfLiteMmapRo) {
// Read the values from input_tensor, based on its type.
// For example, if you have float weights,
const float* weights = GetTensorData<float>(input_tensor);
// <read the weight values...>
}
}