From ccea568d329f583d9e9d67a4aceb32e33dfdc942 Mon Sep 17 00:00:00 2001 From: sjtubinlong Date: Tue, 31 Mar 2020 20:33:50 +0800 Subject: [PATCH] fix coding style --- contrib/RealTimeHumanSeg/python/infer.py | 101 ++++++++++++++--------- 1 file changed, 64 insertions(+), 37 deletions(-) diff --git a/contrib/RealTimeHumanSeg/python/infer.py b/contrib/RealTimeHumanSeg/python/infer.py index 66110bf3..4607c8f6 100644 --- a/contrib/RealTimeHumanSeg/python/infer.py +++ b/contrib/RealTimeHumanSeg/python/infer.py @@ -14,20 +14,47 @@ # limitations under the License. import os -import sys - +import argparse import numpy as np import cv2 import paddle.fluid as fluid +def parse_args(): + """ + Parsing command argments + """ + parser = argparse.ArgumentParser('Realtime Human Segmentation') + parser.add_argument('--model_dir', + type=str, + default='', + help='path of human segmentation model') + parser.add_argument('--img_path', + type=str, + default='', + help='path of input image') + parser.add_argument('--video_path', + type=str, + default='', + help='path of input video') + parser.add_argument('--use_camera', + type=bool, + default=False, + help='input video stream from camera') + parser.add_argument('--use_gpu', + type=bool, + default=False, + help='enable gpu') + return parser.parse_args() + + def get_round(data): """ get round of data """ - round = 0.5 if data >= 0 else -0.5 - return (int)(data + round) + rnd = 0.5 if data >= 0 else -0.5 + return (int)(data + rnd) def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): @@ -35,29 +62,30 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): human segmentation tracking """ check_thres = 8 - h, w = pre_gray.shape[:2] + hgt, wdh = pre_gray.shape[:2] track_cfd = np.zeros_like(prev_cfd) is_track = np.zeros_like(pre_gray) flow_fw = disflow.calc(pre_gray, cur_gray, None) flow_bw = disflow.calc(cur_gray, pre_gray, None) - for r in range(h): - for c in range(w): - fxy_fw = flow_fw[r, c] + for row in range(hgt): + for col in range(wdh): + fxy_fw = flow_fw[row, col] dx_fw = get_round(fxy_fw[0]) - cur_x = dx_fw + c + cur_x = dx_fw + col dy_fw = get_round(fxy_fw[1]) - cur_y = dy_fw + r - if cur_x < 0 or cur_x >= w or cur_y < 0 or cur_y >= h: + cur_y = dy_fw + row + if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt: continue fxy_bw = flow_bw[cur_y, cur_x] dx_bw = get_round(fxy_bw[0]) dy_bw = get_round(fxy_bw[1]) - if ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw)) >= check_thres: + lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw)) + if lmt >= check_thres: continue if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0: dl_weights[cur_y, cur_x] = 0.05 is_track[cur_y, cur_x] = 1 - track_cfd[cur_y, cur_x] = prev_cfd[r, c] + track_cfd[cur_y, cur_x] = prev_cfd[row, col] return track_cfd, is_track, dl_weights @@ -78,7 +106,7 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): else: cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score else: - cur_cfd[x, y] = dl_weights[x,y]*dl_score + (1-dl_weights[x,y])*track_score + cur_cfd[x, y] = dl_weights[x, y] * dl_score + (1 - dl_weights[x, y]) * track_score return cur_cfd @@ -96,22 +124,23 @@ def optflow_handle(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init): """ optical flow handling """ - w, h = scoremap.shape[0], scoremap.shape[1] + width, height = scoremap.shape[0], scoremap.shape[1] cur_cfd = scoremap.copy() if is_init: is_init = False - if h <= 64 or w <= 64: + if height <= 64 or width <= 64: disflow.setFinestScale(1) - elif h <= 160 or w <= 160: + elif height <= 160 or width <= 160: disflow.setFinestScale(2) else: disflow.setFinestScale(3) fusion_cfd = cur_cfd else: - weights = np.ones((w,h), np.float32) * 0.3 - track_cfd, is_track, weights = human_seg_tracking(prev_gray, cur_gray, pre_cfd, weights, disflow) + weights = np.ones((width, height), np.float32) * 0.3 + track_cfd, is_track, weights = human_seg_tracking( + prev_gray, cur_gray, pre_cfd, weights, disflow) fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track) - fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3,3), 0) + fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0) return fusion_cfd @@ -179,8 +208,8 @@ class HumanSeg: optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8) optflow_map = cv2.resize(optflow_map, (ori_w, ori_h)) optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2) - bg = np.ones_like(optflow_map) * 255 - comb = (optflow_map * image + (1 - optflow_map) * bg).astype(np.uint8) + bg_im = np.ones_like(optflow_map) * 255 + comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(np.uint8) return comb def run_predict(self, image): @@ -218,16 +247,12 @@ def predict_video(seg, video_path): out = cv2.VideoWriter('result.avi', cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height)) - id = 1 # Start capturing from video while cap.isOpened(): ret, frame = cap.read() if ret: img_mat = seg.run_predict(frame) out.write(img_mat) - id += 1 - if id >= 51: - break else: break cap.release() @@ -259,23 +284,25 @@ def main(argv): """ Entrypoint of the script """ - if len(argv) < 3: - print('Usage: python infer.py /path/to/model/ /path/to/video') - return + model_dir = args.model_dir + use_gpu = args.use_gpu - model_dir = sys.argv[1] - input_path = sys.argv[2] - use_gpu = int(sys.argv[3]) if len(sys.argv) >= 4 else 0 # Init model mean = [104.008, 116.669, 122.675] scale = [1.0, 1.0, 1.0] eval_size = (192, 192) seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu) - # Run Predicting on a video and result will be saved as result.avi - #predict_camera(seg) - predict_video(seg, input_path) - #predict_image(seg, input_path) + if args.use_camera: + # if enable input video stream from video + predict_camera(seg) + elif args.video_path: + # if video_path valid, do predicting on video + predict_video(seg, args.video_path) + elif args.img_path: + # if img_path valid, do predicting on the image + predict_image(seg, args.img_path) if __name__ == "__main__": - main(sys.argv) + argv = parse_args() + main(argv) -- GitLab