获取使用 Keras 模型检测到的对象的位置
Get location of object detected using Keras Model
好的,所以,我目前正在使用此代码:
import cv2
from PIL import Image, ImageOps
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
classesint = []
classesname = []
file = open("labels.txt", "r")
for line in file:
stripped_line = line.strip()
line_list = stripped_line.split()
classesint.append(int(line_list[0]))
line_list.remove(line_list[0])
classesname.append(" ".join(line_list))
print(classesint)
print(classesname)
file.close()
model = load_model('keras_model.h5')
vid = cv2.VideoCapture(0, cv2.CAP_DSHOW)
while True:
ret, frame = vid.read()
if not ret:
pass
else:
data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
image = Image.fromarray(frame)
size = (224, 224)
image = ImageOps.fit(image, size, Image.ANTIALIAS)
image_array = np.asarray(image)
normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
data[0] = normalized_image_array
prediction = model.predict(data)
clas = np.argmax(prediction, axis = 1)
threshold = 97
detected = False
for i in classesint:
if prediction[0][i]*100 >= threshold:
detected = True
if detected:
# for i in classesint:
# print(prediction[0][i]*100)
print(classesname[int(clas)])
cv2.imshow('frame', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
vid.release()
cv2.destroyAllWindows()
进行实时网络摄像头对象检测,并且工作正常,但我想知道我如何才能看到它检测到该对象的位置?我需要重写整个文件吗?或者我可以添加一些可以做到这一点的班轮吗?
您似乎在使用 class化模型,它输出您输入的图像类型。如果你使用对象检测模型,比如 YOLO 你也可以找到位置,这完全取决于你如何训练模型。
分类模型使用 class 个标签进行训练,检测模型使用 class 个标签 + 对象位置进行训练。因此,他们产生了他们接受训练的输出。
好的,所以,我目前正在使用此代码:
import cv2
from PIL import Image, ImageOps
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
classesint = []
classesname = []
file = open("labels.txt", "r")
for line in file:
stripped_line = line.strip()
line_list = stripped_line.split()
classesint.append(int(line_list[0]))
line_list.remove(line_list[0])
classesname.append(" ".join(line_list))
print(classesint)
print(classesname)
file.close()
model = load_model('keras_model.h5')
vid = cv2.VideoCapture(0, cv2.CAP_DSHOW)
while True:
ret, frame = vid.read()
if not ret:
pass
else:
data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
image = Image.fromarray(frame)
size = (224, 224)
image = ImageOps.fit(image, size, Image.ANTIALIAS)
image_array = np.asarray(image)
normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
data[0] = normalized_image_array
prediction = model.predict(data)
clas = np.argmax(prediction, axis = 1)
threshold = 97
detected = False
for i in classesint:
if prediction[0][i]*100 >= threshold:
detected = True
if detected:
# for i in classesint:
# print(prediction[0][i]*100)
print(classesname[int(clas)])
cv2.imshow('frame', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
vid.release()
cv2.destroyAllWindows()
进行实时网络摄像头对象检测,并且工作正常,但我想知道我如何才能看到它检测到该对象的位置?我需要重写整个文件吗?或者我可以添加一些可以做到这一点的班轮吗?
您似乎在使用 class化模型,它输出您输入的图像类型。如果你使用对象检测模型,比如 YOLO 你也可以找到位置,这完全取决于你如何训练模型。
分类模型使用 class 个标签进行训练,检测模型使用 class 个标签 + 对象位置进行训练。因此,他们产生了他们接受训练的输出。