获取使用 Keras 模型检测到的对象的位置

Get location of object detected using Keras Model

好的,所以,我目前正在使用此代码:

import cv2
from PIL import Image, ImageOps
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

classesint = []
classesname = []
file = open("labels.txt", "r")

for line in file:
  stripped_line = line.strip()
  line_list = stripped_line.split()
  classesint.append(int(line_list[0]))
  line_list.remove(line_list[0])

  classesname.append(" ".join(line_list))
  print(classesint)
  print(classesname)

file.close()
model = load_model('keras_model.h5')

vid = cv2.VideoCapture(0, cv2.CAP_DSHOW)

while True:

    ret, frame = vid.read()
    if not ret:
        pass
    else:

        data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
        image = Image.fromarray(frame)

        size = (224, 224)
        image = ImageOps.fit(image, size, Image.ANTIALIAS)

        image_array = np.asarray(image)
        normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
        data[0] = normalized_image_array

        prediction = model.predict(data)
        clas = np.argmax(prediction, axis = 1)
        threshold = 97
        detected = False
        for i in classesint:
            if prediction[0][i]*100 >= threshold:
                detected = True
        if detected:
            # for i in classesint:
            #     print(prediction[0][i]*100)
            print(classesname[int(clas)])
        
        cv2.imshow('frame', frame)
        
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

vid.release()
cv2.destroyAllWindows()

进行实时网络摄像头对象检测,并且工作正常,但我想知道我如何才能看到它检测到该对象的位置?我需要重写整个文件吗?或者我可以添加一些可以做到这一点的班轮吗?

您似乎在使用 class化模型,它输出您输入的图像类型。如果你使用对象检测模型,比如 YOLO 你也可以找到位置,这完全取决于你如何训练模型。

分类模型使用 class 个标签进行训练,检测模型使用 class 个标签 + 对象位置进行训练。因此,他们产生了他们接受训练的输出。