diff --git a/deploy/python/det_keypoint_unite_infer.py b/deploy/python/det_keypoint_unite_infer.py index c3295559778e2a7c61a68e36cb3971cb3e83f7f7..951864b2a1a86b111f9bb829617133a3b4ea5f98 100644 --- a/deploy/python/det_keypoint_unite_infer.py +++ b/deploy/python/det_keypoint_unite_infer.py @@ -143,6 +143,9 @@ def topdown_unite_predict_video(detector, writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) index = 0 store_res = [] + previous_keypoints = None + keypoint_smoothing = KeypointSmoothing(width, height, filter_type=FLAGS.filter_type, alpha=0.8, beta=1) + while (1): ret, frame = capture.read() if not ret: @@ -161,12 +164,20 @@ def topdown_unite_predict_video(detector, keypoint_res = predict_with_given_det( frame2, results, topdown_keypoint_detector, keypoint_batch_size, FLAGS.run_benchmark) + + if FLAGS.smooth: + current_keypoints = np.array(keypoint_res['keypoint'][0][0]) + smooth_keypoints = keypoint_smoothing.smooth_process(previous_keypoints, current_keypoints) + previous_keypoints = smooth_keypoints + + keypoint_res['keypoint'][0][0] = smooth_keypoints.tolist() im = visualize_pose( frame, keypoint_res, visual_thresh=FLAGS.keypoint_threshold, returnimg=True) + if save_res: store_res.append([ index, keypoint_res['bbox'], @@ -192,6 +203,77 @@ def topdown_unite_predict_video(detector, json.dump(store_res, wf, indent=4) +class KeypointSmoothing(object): + # The following code are modified from: + # https://github.com/610265158/Peppa_Pig_Face_Engine/blob/7bb1066ad3fbb12697924ba7f9287bf198c15232/lib/core/LK/lk.py + + def __init__(self, width, height, filter_type, alpha=0.5, fc_d=1, fc_min=1, beta=0): + super(KeypointSmoothing, self).__init__() + self.image_width = width + self.image_height = height + self.threshold = [0.005, 0.005, 0.005, 0.005, 0.005, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01] + self.filter_type = filter_type + self.alpha = alpha + self.dx_prev_hat = None + self.x_prev_hat = None + self.fc_d = fc_d + self.fc_min = fc_min + self.beta = beta + + if self.filter_type == 'one_euro': + self.smooth_func = self.one_euro_filter + elif self.filter_type == 'ema': + self.smooth_func = self.exponential_smoothing + else: + raise ValueError('filter type must be one_euro or ema') + + def smooth_process(self, previous_keypoints, current_keypoints): + if previous_keypoints is None: + previous_keypoints = current_keypoints + result = current_keypoints + else: + result = [] + num_keypoints = len(current_keypoints) + for i in range(num_keypoints): + result.append(self.smooth(previous_keypoints[i], current_keypoints[i], self.threshold[i])) + return np.array(result) + + + def smooth(self, previous_keypoint, current_keypoint, threshold): + distance = np.sqrt(np.square((current_keypoint[0] - previous_keypoint[0]) / self.image_width) + np.square((current_keypoint[1] - previous_keypoint[1]) / self.image_height)) + if distance < threshold: + result = previous_keypoint + else: + result = self.smooth_func(previous_keypoint, current_keypoint) + return result + + + def one_euro_filter(self, x_prev, x_cur): + te = 1 + self.alpha = self.smoothing_factor(te, self.fc_d) + if self.x_prev_hat is None: + self.x_prev_hat = x_prev + dx_cur = (x_cur - self.x_prev_hat) / te + if self.dx_prev_hat is None: + self.dx_prev_hat = 0 + dx_cur_hat = self.exponential_smoothing(self.dx_prev_hat, dx_cur) + + fc = self.fc_min + self.beta * np.abs(dx_cur_hat) + self.alpha = self.smoothing_factor(te, fc) + x_cur_hat = self.exponential_smoothing(self.x_prev_hat, x_cur) + self.dx_prev_hat = dx_cur_hat + self.x_prev_hat = x_cur_hat + return x_cur_hat + + + def smoothing_factor(self, te, fc): + r = 2 * math.pi * fc * te + return r / (r + 1) + + def exponential_smoothing(self, x_prev, x_cur): + return self.alpha * x_cur + (1 - self.alpha) * x_prev + + def main(): deploy_file = os.path.join(FLAGS.det_model_dir, 'infer_cfg.yml') with open(deploy_file) as f: diff --git a/deploy/python/det_keypoint_unite_utils.py b/deploy/python/det_keypoint_unite_utils.py index 26344628a3e10457a394f351fc64f7049a4245bb..129969df2db0e079e363a09c33861d1ff3cc6ca2 100644 --- a/deploy/python/det_keypoint_unite_utils.py +++ b/deploy/python/det_keypoint_unite_utils.py @@ -126,4 +126,16 @@ def argsparser(): "3) rects: list of rect [xmin, ymin, xmax, ymax]" "4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list" "5) scores: mean of all joint conf")) + parser.add_argument( + '--smooth', + type=ast.literal_eval, + default=False, + help='smoothing keypoints for each frame, new incoming keypoints will be more stable.' + ) + parser.add_argument( + '--filter_type', + type=str, + default='one_euro', + help='when set --smooth True, choose filter type you want to use, it can be one_euro or ema.' + ) return parser