Unverified commit 27da92b0, authored by Wanli, committed by GitHub

Let mp_palmdet support multiple palms detection (#85)

Parent 2269c87e
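In summary, this change makes MPPalmDet.infer return a single NumPy array with one row per detected palm (4 bounding-box coordinates, 14 landmark coordinates, and 1 score, i.e. 19 values per row) instead of a (score, palm_box, palm_landmarks) tuple for a single palm, and it adds a topK parameter that is forwarded to NMS. Below is a minimal sketch of the new calling pattern; the module name mp_palmdet, the model path, and the sample image are assumptions for illustration, not part of this diff.

import cv2 as cv
from mp_palmdet import MPPalmDet  # module name assumed from the repo layout

# Thresholds mirror the values shown in this diff; topK=5000 is the new constructor default.
model = MPPalmDet(modelPath='./palm_detection_mediapipe_2022may.onnx',
                  nmsThreshold=0.3, scoreThreshold=0.5, topK=5000)

image = cv.imread('hands.jpg')   # any test image with one or more hands (assumed)
results = model.infer(image)     # ndarray of shape (num_palms, 19); empty when nothing is found
for palm in results:
    palm_box = palm[0:4]                       # x1, y1, x2, y2
    palm_landmarks = palm[4:-1].reshape(7, 2)  # 7 (x, y) landmark points
    score = palm[-1]                           # confidence score
    print(palm_box, score)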
@@ -18,4 +18,5 @@ Model:
  modelPath: "models/palm_detection_mediapipe/palm_detection_mediapipe_2022may.onnx"
  scoreThreshold: 0.5
  nmsThreshold: 0.3
  topK: 1
@@ -36,23 +36,37 @@ parser.add_argument('--save', '-s', type=str, default=False, help='Set true to s
parser.add_argument('--vis', '-v', type=str2bool, default=True, help='Set true to open a window for result visualization. This flag is invalid when using camera.')
args = parser.parse_args()
def visualize(image, score, palm_box, palm_landmarks, fps=None):
def visualize(image, results, print_results=False, fps=None):
    output = image.copy()
    if fps is not None:
        cv.putText(output, 'FPS: {:.2f}'.format(fps), (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))
    # put score
    palm_box = palm_box.astype(np.int32)
    cv.putText(output, '{:.4f}'.format(score), (palm_box[0], palm_box[1]+12), cv.FONT_HERSHEY_DUPLEX, 0.5, (0, 255, 0))
    for idx, palm in enumerate(results):
        score = palm[-1]
        palm_box = palm[0:4]
        palm_landmarks = palm[4:-1].reshape(7, 2)
        # draw box
        cv.rectangle(output, (palm_box[0], palm_box[1]), (palm_box[2], palm_box[3]), (0, 255, 0), 2)
        # put score
        palm_box = palm_box.astype(np.int32)
        cv.putText(output, '{:.4f}'.format(score), (palm_box[0], palm_box[1]+12), cv.FONT_HERSHEY_DUPLEX, 0.5, (0, 255, 0))
        # draw points
        palm_landmarks = palm_landmarks.astype(np.int32)
        for p in palm_landmarks:
            cv.circle(output, p, 2, (0, 0, 255), 2)
    # draw box
    cv.rectangle(output, (palm_box[0], palm_box[1]), (palm_box[2], palm_box[3]), (0, 255, 0), 2)
    # draw points
    palm_landmarks = palm_landmarks.astype(np.int32)
    for p in palm_landmarks:
        cv.circle(output, p, 2, (0, 0, 255), 2)
        # Print results
        if print_results:
            print('-----------palm {}-----------'.format(idx + 1))
            print('score: {:.2f}'.format(score))
            print('palm box: {}'.format(palm_box))
            print('palm landmarks: ')
            for plm in palm_landmarks:
                print('\t{}'.format(plm))
    return output
@@ -69,30 +83,23 @@ if __name__ == '__main__':
        image = cv.imread(args.input)
        # Inference
        score, palm_box, palm_landmarks = model.infer(image)
        if score is None or palm_box is None or palm_landmarks is None:
        results = model.infer(image)
        if len(results) == 0:
            print('Hand not detected')
        else:
            # Print results
            print('score: {:.2f}'.format(score))
            print('palm box: {}'.format(palm_box))
            print('palm_landmarks: ')
            for plm in enumerate(palm_landmarks):
                print('\t{}'.format(plm))
            # Draw results on the input image
            image = visualize(image, score, palm_box, palm_landmarks)
            # Save results if save is true
            if args.save:
                print('Results saved to result.jpg\n')
                cv.imwrite('result.jpg', image)
            # Visualize results in a new window
            if args.vis:
                cv.namedWindow(args.input, cv.WINDOW_AUTOSIZE)
                cv.imshow(args.input, image)
                cv.waitKey(0)
            # Draw results on the input image
            image = visualize(image, results, print_results=True)
            # Save results if save is true
            if args.save:
                print('Results saved to result.jpg\n')
                cv.imwrite('result.jpg', image)
            # Visualize results in a new window
            if args.vis:
                cv.namedWindow(args.input, cv.WINDOW_AUTOSIZE)
                cv.imshow(args.input, image)
                cv.waitKey(0)
    else: # Omit input to call default camera
        deviceId = 0
        cap = cv.VideoCapture(deviceId)
@@ -106,12 +113,11 @@ if __name__ == '__main__':
            # Inference
            tm.start()
            score, palm_box, palm_landmarks = model.infer(frame)
            results = model.infer(frame)
            tm.stop()
            # Draw results on the input image
            if score is not None and palm_box is not None and palm_landmarks is not None:
                frame = visualize(frame, score, palm_box, palm_landmarks, fps=tm.getFPS())
            frame = visualize(frame, results, fps=tm.getFPS())
            # Visualize results in a new Window
            cv.imshow('MPPalmDet Demo', frame)
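To make the per-row layout that the new visualize() consumes concrete, here is a small standalone sketch with fabricated numbers (not taken from the diff): each row is sliced into a box, seven landmarks, and a score, exactly as the loop in visualize() does.

import numpy as np

# One fabricated result row: 4 box values + 14 landmark values + 1 score = 19 entries.
palm = np.array([ 80.,  90., 300., 310.,
                 100., 120., 150., 110., 200., 115., 250., 130.,
                 120., 250., 180., 260., 240., 255.,
                   0.92])

score = palm[-1]                                             # confidence score
palm_box = palm[0:4].astype(np.int32)                        # x1, y1, x2, y2
palm_landmarks = palm[4:-1].reshape(7, 2).astype(np.int32)   # 7 (x, y) points
assert palm_landmarks.shape == (7, 2)
print(score, palm_box, palm_landmarks)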
@@ -2,10 +2,11 @@ import numpy as np
import cv2 as cv
class MPPalmDet:
    def __init__(self, modelPath, nmsThreshold=0.3, scoreThreshold=0.5, backendId=0, targetId=0):
    def __init__(self, modelPath, nmsThreshold=0.3, scoreThreshold=0.5, topK=5000, backendId=0, targetId=0):
        self.model_path = modelPath
        self.nms_threshold = nmsThreshold
        self.score_threshold = scoreThreshold
        self.topK = topK
        self.backend_id = backendId
        self.target_id = targetId
@@ -48,9 +49,9 @@ class MPPalmDet:
        output_blob = self.model.forward()
        # Postprocess
        score, palm_box, palm_landmarks = self._postprocess(output_blob, np.array([w, h]))
        results = self._postprocess(output_blob, np.array([w, h]))
        return (score, palm_box, palm_landmarks)
        return results
    def _postprocess(self, output_blob, original_shape):
        score = output_blob[0, :, 0]
@@ -68,17 +69,26 @@ class MPPalmDet:
        xy2 = (cxy_delta + wh_delta / 2 + self.anchors[:, :2]) * original_shape
        boxes = np.concatenate([xy1, xy2], axis=1)
        # NMS
        keep_idx = cv.dnn.NMSBoxes(boxes, score, self.score_threshold, self.nms_threshold, top_k=1)
        keep_idx = cv.dnn.NMSBoxes(boxes, score, self.score_threshold, self.nms_threshold, top_k=self.topK)
        if len(keep_idx) == 0:
            return None, None, None
        selected_score = score[keep_idx][0]
        selected_box = boxes[keep_idx][0]
            return np.empty(shape=(0, 19))
        selected_score = score[keep_idx]
        selected_box = boxes[keep_idx]
        # get landmarks
        selected_landmarks = landmark_delta[keep_idx].reshape(7, 2)
        selected_landmarks = (selected_landmarks / self.input_size + self.anchors[keep_idx]) * original_shape
        selected_landmarks = landmark_delta[keep_idx].reshape(-1, 7, 2)
        selected_landmarks = selected_landmarks / self.input_size
        selected_anchors = self.anchors[keep_idx]
        for idx, landmark in enumerate(selected_landmarks):
            landmark += selected_anchors[idx]
        selected_landmarks *= original_shape
        return (selected_score, selected_box, selected_landmarks)
        # [
        #   [bbox_coords, landmarks_coords, score]
        #   ...
        #   [bbox_coords, landmarks_coords, score]
        # ]
        return np.c_[selected_box.reshape(-1, 4), selected_landmarks.reshape(-1, 14), selected_score.reshape(-1, 1)]
    def _load_anchors(self):
        return np.array([[0.015625, 0.015625],
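The final return statement of _postprocess stacks the kept boxes, landmarks, and scores column-wise. A quick shape check with random data (not part of the diff) shows how np.c_ produces the (num_palms, 19) rows described in the comment above, and why the empty (0, 19) array keeps the demo's len(results) == 0 check working.

import numpy as np

N = 3                                         # pretend NMS kept three palms
selected_box = np.random.rand(N, 4)           # x1, y1, x2, y2 per palm
selected_landmarks = np.random.rand(N, 7, 2)  # 7 (x, y) landmarks per palm
selected_score = np.random.rand(N)            # one score per palm

results = np.c_[selected_box.reshape(-1, 4),
                selected_landmarks.reshape(-1, 14),
                selected_score.reshape(-1, 1)]
assert results.shape == (N, 19)

no_detection = np.empty(shape=(0, 19))        # returned when NMS keeps nothing
assert len(no_detection) == 0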