How to run TFLite model inference on video input
I am trying to test my exported MobileNet v2 SSDLite model (https://drive.google.com/open?id=1htyBE6R62yVCV8v-9muEJ_lGmoPxQMmJ) on a video. I found an answer and modified it in a few places to fit my model:
import cv2
from PIL import Image
import numpy as np
import tensorflow as tf

def read_tensor_from_readed_frame(frame, input_height=300, input_width=300,
                                  input_mean=128, input_std=128):
    output_name = "normalized"
    # float_caster = tf.cast(frame, tf.float32)
    float_caster = tf.cast(frame, tf.uint8)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    sess = tf.Session()
    result = sess.run(normalized)
    return result

def load_labels(label_file):
    label = []
    proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
    for l in proto_as_ascii_lines:
        label.append(l.rstrip())
    return label

def VideoSrcInit(paath):
    cap = cv2.VideoCapture(paath)
    flag, image = cap.read()
    if flag:
        print("Valid Video Path. Lets move to detection!")
    else:
        raise ValueError("Video Initialization Failed. Please make sure video path is valid.")
    return cap

def main():
    Labels_Path = "C:/MachineLearning/CV/coco-labelmap.txt"
    Model_Path = "C:/MachineLearning/CV/previous_float_model_converted_from_ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tflite"
    input_path = "C:/MachineLearning/CV/Object_Tracking/video2.mp4"

    ## Loading labels
    labels = load_labels(Labels_Path)

    ## Load tflite model and allocate tensors
    interpreter = tf.lite.Interpreter(model_path=Model_Path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_shape = input_details[0]['shape']

    ## Read video
    cap = VideoSrcInit(input_path)

    while True:
        ok, cv_image = cap.read()
        if not ok:
            break

        ## Converting the read frame to RGB as OpenCV reads frames in BGR
        image = Image.fromarray(cv_image).convert('RGB')

        ## Converting image into tensor
        image_tensor = read_tensor_from_readed_frame(image, 300, 300)

        ## Test model
        interpreter.set_tensor(input_details[0]['index'], image_tensor)
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        ## You need to check the output of the output_data variable and
        ## map it on the frame in order to draw the bounding boxes.

        cv2.namedWindow("cv_image", cv2.WINDOW_NORMAL)
        cv2.imshow("cv_image", cv_image)

        ## Use p to pause the video and use q to terminate the program
        key = cv2.waitKey(10) & 0xFF
        if key == ord("q"):
            break
        elif key == ord("p"):
            cv2.waitKey(0)
            continue

    cap.release()

if __name__ == '__main__':
    main()
When I run this script with my tflite model, the FPS is extremely low and the video is almost frozen. So what is wrong with the script?
I solved it myself. Here is the script:
import numpy as np
import tensorflow as tf
import cv2
import time

print(tf.__version__)

Model_Path = "C:/MachineLearning/CV/uint8_dequantized_model_converted_from_exported_model.tflite"
Video_path = "C:/MachineLearning/CV/Object_Tracking/video2.mp4"

interpreter = tf.lite.Interpreter(model_path=Model_Path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
               'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
               'teddy bear', 'hair drier', 'toothbrush']

cap = cv2.VideoCapture(Video_path)
ok, frame_image = cap.read()
original_image_height, original_image_width, _ = frame_image.shape
thickness = original_image_height // 500
fontsize = original_image_height / 1500
print(thickness)
print(fontsize)

while True:
    ok, frame_image = cap.read()
    if not ok:
        break

    model_interpreter_start_time = time.time()

    ## Preprocess with OpenCV/NumPy instead of building TF ops per frame
    resize_img = cv2.resize(frame_image, (300, 300), interpolation=cv2.INTER_CUBIC)
    reshape_image = resize_img.reshape(300, 300, 3)
    image_np_expanded = np.expand_dims(reshape_image, axis=0)
    image_np_expanded = image_np_expanded.astype('uint8')  # float32

    interpreter.set_tensor(input_details[0]['index'], image_np_expanded)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])    # bounding boxes [ymin, xmin, ymax, xmax], normalized
    output_data_1 = interpreter.get_tensor(output_details[1]['index'])  # class indices
    output_data_2 = interpreter.get_tensor(output_details[2]['index'])  # confidence scores
    output_data_3 = interpreter.get_tensor(output_details[3]['index'])  # number of detections
    each_interpreter_time = time.time() - model_interpreter_start_time

    for i in range(len(output_data_1[0])):
        confidence_threshold = output_data_2[0][i]
        if confidence_threshold > 0.3:
            label = "{}: {:.2f}% ".format(class_names[int(output_data_1[0][i])], output_data_2[0][i] * 100)
            label2 = "inference time : {:.3f}s".format(each_interpreter_time)
            left_up_corner = (int(output_data[0][i][1] * original_image_width), int(output_data[0][i][0] * original_image_height))
            left_up_corner_higher = (int(output_data[0][i][1] * original_image_width), int(output_data[0][i][0] * original_image_height) - 20)
            right_down_corner = (int(output_data[0][i][3] * original_image_width), int(output_data[0][i][2] * original_image_height))
            cv2.rectangle(frame_image, left_up_corner_higher, right_down_corner, (0, 255, 0), thickness)
            cv2.putText(frame_image, label, left_up_corner_higher, cv2.FONT_HERSHEY_DUPLEX, fontsize, (255, 255, 255), thickness=thickness)
            cv2.putText(frame_image, label2, (30, 30), cv2.FONT_HERSHEY_DUPLEX, fontsize, (255, 255, 255), thickness=thickness)

    cv2.namedWindow('detect_result', cv2.WINDOW_NORMAL)
    # cv2.resizeWindow('detect_result', 800, 600)
    cv2.imshow("detect_result", frame_image)

    ## Use space to pause the video and use q to terminate the program
    key = cv2.waitKey(10) & 0xFF
    if key == ord("q"):
        break
    elif key == 32:
        cv2.waitKey(0)
        continue

cap.release()
cv2.destroyAllWindows()
But the inference speed is still slow, because TFLite ops are optimized for mobile devices rather than for desktop.
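On that desktop-speed point, one small thing that sometimes helps is letting the TFLite CPU kernels run multi-threaded. This is only a minimal sketch, assuming a TensorFlow 2.x release whose tf.lite.Interpreter constructor accepts the num_threads argument; the model path is the same one used in the script above, and the set_tensor / invoke / get_tensor loop stays unchanged.

import tensorflow as tf

Model_Path = "C:/MachineLearning/CV/uint8_dequantized_model_converted_from_exported_model.tflite"

# num_threads lets the built-in CPU kernels use several threads,
# which can noticeably reduce per-frame latency on a desktop CPU.
interpreter = tf.lite.Interpreter(model_path=Model_Path, num_threads=4)
interpreter.allocate_tensors()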