使用 TF 2.0 将 saved_model 转换为 TFLite 模型
Converting saved_model to TFLite model using TF 2.0
目前我正在致力于将自定义对象检测模型(使用 SSD 和初始网络训练)转换为量化的 TFLite 模型。我可以使用以下代码片段(使用 Tensorflow 1.4)将自定义对象检测模型从冻结图转换为量化的 TFLite 模型:
converter = tf.lite.TFLiteConverter.from_frozen_graph(args["model"],input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])
converter.allow_custom_ops=True
converter.post_training_quantize=True
tflite_model = converter.convert()
open(args["output"], "wb").write(tflite_model)
但是 tf.lite.TFLiteConverter.from_frozen_graph
class 方法不适用于 Tensorflow 2.0 (refer this link)。所以我尝试使用 tf.lite.TFLiteConverter.from_saved_model
class 方法转换模型。代码片段如下所示:
converter = tf.lite.TFLiteConverter.from_saved_model("/content/") # Path to saved_model directory
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
以上代码片段抛出以下错误:
ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.
我试图将 input_shapes
作为参数传递
converter = tf.lite.TFLiteConverter.from_saved_model("/content/",input_shapes={"image_tensor" : [1,300,300,3]})
但它抛出以下错误:
TypeError: from_saved_model() got an unexpected keyword argument 'input_shapes'
我错过了什么吗?请随时指正!
我使用 tf.compat.v1.lite.TFLiteConverter.from_frozen_graph
得到了解决方案。 compat.v1
将 TF1.x
的功能引入 TF2.x
。
以下是完整代码:
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph("/content/tflite_graph.pb",input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])
converter.allow_custom_ops=True
# Convert the model to quantized TFLite model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Write a model using the following line
open("/content/uno_mobilenetV2.tflite", "wb").write(tflite_model)
目前我正在致力于将自定义对象检测模型(使用 SSD 和初始网络训练)转换为量化的 TFLite 模型。我可以使用以下代码片段(使用 Tensorflow 1.4)将自定义对象检测模型从冻结图转换为量化的 TFLite 模型:
converter = tf.lite.TFLiteConverter.from_frozen_graph(args["model"],input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])
converter.allow_custom_ops=True
converter.post_training_quantize=True
tflite_model = converter.convert()
open(args["output"], "wb").write(tflite_model)
但是 tf.lite.TFLiteConverter.from_frozen_graph
class 方法不适用于 Tensorflow 2.0 (refer this link)。所以我尝试使用 tf.lite.TFLiteConverter.from_saved_model
class 方法转换模型。代码片段如下所示:
converter = tf.lite.TFLiteConverter.from_saved_model("/content/") # Path to saved_model directory
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
以上代码片段抛出以下错误:
ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.
我试图将 input_shapes
作为参数传递
converter = tf.lite.TFLiteConverter.from_saved_model("/content/",input_shapes={"image_tensor" : [1,300,300,3]})
但它抛出以下错误:
TypeError: from_saved_model() got an unexpected keyword argument 'input_shapes'
我错过了什么吗?请随时指正!
我使用 tf.compat.v1.lite.TFLiteConverter.from_frozen_graph
得到了解决方案。 compat.v1
将 TF1.x
的功能引入 TF2.x
。
以下是完整代码:
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph("/content/tflite_graph.pb",input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])
converter.allow_custom_ops=True
# Convert the model to quantized TFLite model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Write a model using the following line
open("/content/uno_mobilenetV2.tflite", "wb").write(tflite_model)