diff --git a/contrib/RealTimeHumanSeg/python/infer.py b/contrib/RealTimeHumanSeg/python/infer.py index 0ebc57b029051d208cd7398c576c34f9dbc8f78f..dc818249b91ad8c41616022370f2df3db989bdcb 100644 --- a/contrib/RealTimeHumanSeg/python/infer.py +++ b/contrib/RealTimeHumanSeg/python/infer.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================== +"""Python Inference solution for realtime humansegmentation""" import os import argparse @@ -21,54 +23,31 @@ import cv2 import paddle.fluid as fluid -def parse_args(): - """ - Parsing command argments - """ - parser = argparse.ArgumentParser('Realtime Human Segmentation') - parser.add_argument('--model_dir', - type=str, - default='', - help='path of human segmentation model') - parser.add_argument('--img_path', - type=str, - default='', - help='path of input image') - parser.add_argument('--video_path', - type=str, - default='', - help='path of input video') - parser.add_argument('--use_camera', - type=bool, - default=False, - help='input video stream from camera') - parser.add_argument('--use_gpu', - type=bool, - default=False, - help='enable gpu') - return parser.parse_args() - - -def get_round(data): - """ - get round of data - """ - rnd = 0.5 if data >= 0 else -0.5 - return (int)(data + rnd) - - def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): - """ - human segmentation tracking + """Optical flow tracking for human segmentation + Args: + pre_gray: Grayscale of previous frame. + cur_gray: Grayscale of current frame. + prev_cfd: Optical flow of previous frame. + dl_weights: Merged weights data. + disflow: A data structure represents optical flow. + Returns: + is_track: Binary graph, whethe a pixel matched with a optical flow point. + track_cfd: tracking optical flow image. """ check_thres = 8 hgt, wdh = pre_gray.shape[:2] track_cfd = np.zeros_like(prev_cfd) is_track = np.zeros_like(pre_gray) + # compute forward optical flow flow_fw = disflow.calc(pre_gray, cur_gray, None) + # compute backword optical flow flow_bw = disflow.calc(cur_gray, pre_gray, None) + get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -0.5) for row in range(hgt): for col in range(wdh): + # Calculate new coordinate after optfow process. + # (row, col) -> (cur_x, cur_y) fxy_fw = flow_fw[row, col] dx_fw = get_round(fxy_fw[0]) cur_x = dx_fw + col @@ -79,20 +58,27 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): fxy_bw = flow_bw[cur_y, cur_x] dx_bw = get_round(fxy_bw[0]) dy_bw = get_round(fxy_bw[1]) + # Filt the Optical flow point with a threshold lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw)) if lmt >= check_thres: continue + # Downgrade still points if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0: dl_weights[cur_y, cur_x] = 0.05 is_track[cur_y, cur_x] = 1 track_cfd[cur_y, cur_x] = prev_cfd[row, col] - return track_cfd, is_track, dl_weights def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): - """ - human segmentation tracking fuse + """Fusion of Optical flow track and segmentation + Args: + track_cfd: Optical flow track. + dl_cfd: Segmentation result of current frame. + dl_weights: Merged weights data. + is_track: Binary graph, whethe a pixel matched with a optical flow point. + Returns: + cur_cfd: Fusion of Optical flow track and segmentation result. """ cur_cfd = dl_cfd.copy() idxs = np.where(is_track > 0) @@ -111,8 +97,13 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): def threshold_mask(img, thresh_bg, thresh_fg): - """ - threshold mask + """Threshold mask for image foreground and background + Args: + img : Original image, an instance of np.uint8 array. + thresh_bg : Threshold for background, set to 0 when less than it. + thresh_fg : Threshold for foreground, set to 1 when greater than it. + Returns: + dst : Image after set thresthold mask, ans instance of np.float32 array. """ dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg) dst[np.where(dst > 1)] = 1 @@ -121,8 +112,13 @@ def threshold_mask(img, thresh_bg, thresh_fg): def optflow_handle(cur_gray, scoremap, is_init): - """ - optical flow handling + """Processing optical flow and segmentation result. + Args: + cur_gray : Grayscale of current frame. + scoremap : Segmentation result of current frame. + is_init : True only when process the first frame of a video. + Returns: + dst : Image after set thresthold mask, ans instance of np.float32 array. """ width, height = scoremap.shape[0], scoremap.shape[1] disflow = cv2.DISOpticalFlow_create( @@ -149,18 +145,25 @@ def optflow_handle(cur_gray, scoremap, is_init): class HumanSeg: - """ - Human Segmentation Class + """Human Segmentation Class + This Class instance will load the inference model and do inference + on input image object. + + It includes the key stages for a object segmentation inference task. + Call run_predict on your image and it will return a processed image. """ def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False): + self.mean = np.array(mean).reshape((3, 1, 1)) self.scale = np.array(scale).reshape((3, 1, 1)) self.eval_size = eval_size self.load_model(model_dir, use_gpu) def load_model(self, model_dir, use_gpu): - """ - Load model from model_dir + """Load paddle inference model. + Args: + model_dir: The inference model path includes `__model__` and `__params__`. + use_gpu: Enable gpu if use_gpu is True """ prog_file = os.path.join(model_dir, '__model__') params_file = os.path.join(model_dir, '__params__') @@ -176,8 +179,12 @@ class HumanSeg: self.predictor = fluid.core.create_paddle_predictor(config) def preprocess(self, image): - """ - preprocess image: hwc_rgb to chw_bgr + """Preprocess input image. + Convert hwc_rgb to chw_bgr. + Args: + image: The input opencv image object. + Returns: + A preprocessed image object. """ img_mat = cv2.resize( image, self.eval_size, interpolation=cv2.INTER_LINEAR) @@ -193,8 +200,12 @@ class HumanSeg: return img_mat def postprocess(self, image, output_data): - """ - postprocess result: merge background with segmentation result + """Postprocess the inference result and original input image. + Args: + image: The original opencv image object. + output_data: The inference output of paddle's humansegmentation model. + Returns: + The result merged original image and segmentation result with optical-flow improvement. """ scoremap = output_data[0, 1, :, :] scoremap = (scoremap * 255).astype(np.uint8) @@ -213,8 +224,12 @@ class HumanSeg: return comb def run_predict(self, image): - """ - run predict: return segmentation image mat + """Run Predicting on an opencv image object. + Preprocess the image, do inference, and then postprocess the infering output. + Args: + image: A valid opencv image object. + Returns: + The segmentation result which represents as an opencv image object. """ im_mat = self.preprocess(image) im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32')) @@ -224,8 +239,13 @@ class HumanSeg: def predict_image(seg, image_path): - """ - Do Predicting on a single image + """Do Predicting on a image file. + Decoding the image file and do predicting on it. + The result will be saved as `result.jpeg`. + Args: + seg: The HumanSeg Object which holds a inference model. + Do preprocessing / predicting / postprocessing on a input image object. + image_path: Path of the image file needs to be processed. """ img_mat = cv2.imread(image_path) img_mat = seg.run_predict(img_mat) @@ -233,8 +253,13 @@ def predict_image(seg, image_path): def predict_video(seg, video_path): - """ - Do Predicting on a video + """Do Predicting on a video file. + Decoding the video file and do predicting on each frame. + All result will be saved as `result.avi`. + Args: + seg: The HumanSeg Object which holds a inference model. + Do preprocessing / predicting / postprocessing on a input image object. + video_path: Path of a video file needs to be processed. """ cap = cv2.VideoCapture(video_path) if not cap.isOpened(): @@ -260,8 +285,12 @@ def predict_video(seg, video_path): def predict_camera(seg): - """ - Do Predicting on a camera video stream: Press q to exit + """Do Predicting on a camera video stream. + Capturing each video frame from camera and do predicting on it. + All result frames will be shown in a GUI window. + Args: + seg: The HumanSeg Object which holds a inference model. + Do preprocessing / predicting / postprocessing on a input image object. """ cap = cv2.VideoCapture(0) if not cap.isOpened(): @@ -281,8 +310,14 @@ def predict_camera(seg): def main(args): - """ - Entrypoint of the script + """Real Entrypoint of the script. + Load the human segmentation inference model and do predicting on the input resource. + Support three types of input: camera stream / video file / image file. + Args: + args: The command-line args for inference model. + Open camera and do predicting on camera stream while `args.use_camera` is true. + Open the video file and do predicting on it while `args.video_path` is valid. + Open the image file and do predicting on it while `args.img_path` is valid. """ model_dir = args.model_dir use_gpu = args.use_gpu @@ -293,16 +328,43 @@ def main(args): eval_size = (192, 192) seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu) if args.use_camera: - # if enable input video stream from video + # if enable input video stream from camera predict_camera(seg) elif args.video_path: - # if video_path valid, do predicting on video + # if video_path valid, do predicting on the video predict_video(seg, args.video_path) elif args.img_path: # if img_path valid, do predicting on the image predict_image(seg, args.img_path) +def parse_args(): + """Parsing command-line argments + """ + parser = argparse.ArgumentParser('Realtime Human Segmentation') + parser.add_argument('--model_dir', + type=str, + default='', + help='path of human segmentation model') + parser.add_argument('--img_path', + type=str, + default='', + help='path of input image') + parser.add_argument('--video_path', + type=str, + default='', + help='path of input video') + parser.add_argument('--use_camera', + type=bool, + default=False, + help='input video stream from camera') + parser.add_argument('--use_gpu', + type=bool, + default=False, + help='enable gpu') + return parser.parse_args() + + if __name__ == "__main__": args = parse_args() main(args)