From f71ef43fac2b8ac9a44ad8086bc96c1b474501bd Mon Sep 17 00:00:00 2001 From: sjtubinlong Date: Thu, 2 Apr 2020 16:03:06 +0800 Subject: [PATCH] RealTimeHumanSeg: add comments --- contrib/RealTimeHumanSeg/python/infer.py | 169 ++++++++++------------- 1 file changed, 72 insertions(+), 97 deletions(-) diff --git a/contrib/RealTimeHumanSeg/python/infer.py b/contrib/RealTimeHumanSeg/python/infer.py index dc818249..73df081e 100644 --- a/contrib/RealTimeHumanSeg/python/infer.py +++ b/contrib/RealTimeHumanSeg/python/infer.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Python Inference solution for realtime humansegmentation""" +"""实时人像分割Python预测部署""" import os import argparse @@ -24,29 +24,29 @@ import paddle.fluid as fluid def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): - """Optical flow tracking for human segmentation - Args: - pre_gray: Grayscale of previous frame. - cur_gray: Grayscale of current frame. - prev_cfd: Optical flow of previous frame. - dl_weights: Merged weights data. - disflow: A data structure represents optical flow. - Returns: - is_track: Binary graph, whethe a pixel matched with a optical flow point. - track_cfd: tracking optical flow image. + """计算光流跟踪匹配点和光流图 + 输入参数: + pre_gray: 上一帧灰度图 + cur_gray: 当前帧灰度图 + prev_cfd: 上一帧光流图 + dl_weights: 融合权重图 + disflow: 光流数据结构 + 返回值: + is_track: 光流点跟踪二值图,即是否具有光流点匹配 + track_cfd: 光流跟踪图 """ check_thres = 8 hgt, wdh = pre_gray.shape[:2] track_cfd = np.zeros_like(prev_cfd) is_track = np.zeros_like(pre_gray) - # compute forward optical flow + # 计算前向光流 flow_fw = disflow.calc(pre_gray, cur_gray, None) - # compute backword optical flow + # 计算后向光流 flow_bw = disflow.calc(cur_gray, pre_gray, None) get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -0.5) for row in range(hgt): for col in range(wdh): - # Calculate new coordinate after optfow process. + # 计算光流处理后对应点坐标 # (row, col) -> (cur_x, cur_y) fxy_fw = flow_fw[row, col] dx_fw = get_round(fxy_fw[0]) @@ -58,11 +58,11 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): fxy_bw = flow_bw[cur_y, cur_x] dx_bw = get_round(fxy_bw[0]) dy_bw = get_round(fxy_bw[1]) - # Filt the Optical flow point with a threshold + # 光流移动小于阈值 lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw)) if lmt >= check_thres: continue - # Downgrade still points + # 静止点降权 if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0: dl_weights[cur_y, cur_x] = 0.05 is_track[cur_y, cur_x] = 1 @@ -71,14 +71,14 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): - """Fusion of Optical flow track and segmentation - Args: - track_cfd: Optical flow track. - dl_cfd: Segmentation result of current frame. - dl_weights: Merged weights data. - is_track: Binary graph, whethe a pixel matched with a optical flow point. - Returns: - cur_cfd: Fusion of Optical flow track and segmentation result. + """光流追踪图和人像分割结构融合 + 输入参数: + track_cfd: 光流追踪图 + dl_cfd: 当前帧分割结果 + dl_weights: 融合权重图 + is_track: 光流点匹配二值图 + 返回值: + cur_cfd: 光流跟踪图和人像分割结果融合图 """ cur_cfd = dl_cfd.copy() idxs = np.where(is_track > 0) @@ -97,13 +97,13 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): def threshold_mask(img, thresh_bg, thresh_fg): - """Threshold mask for image foreground and background - Args: - img : Original image, an instance of np.uint8 array. - thresh_bg : Threshold for background, set to 0 when less than it. - thresh_fg : Threshold for foreground, set to 1 when greater than it. - Returns: - dst : Image after set thresthold mask, ans instance of np.float32 array. + """设置背景和前景阈值mask + 输入参数: + img : 原始图像, np.uint8 类型. + thresh_bg : 背景阈值百分比,低于该值置为0. + thresh_fg : 前景阈值百分比,超过该值置为1. + 返回值: + dst : 原始图像设置完前景背景阈值mask结果, np.float32 类型. """ dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg) dst[np.where(dst > 1)] = 1 @@ -112,13 +112,13 @@ def threshold_mask(img, thresh_bg, thresh_fg): def optflow_handle(cur_gray, scoremap, is_init): - """Processing optical flow and segmentation result. + """光流优化 Args: - cur_gray : Grayscale of current frame. - scoremap : Segmentation result of current frame. - is_init : True only when process the first frame of a video. + cur_gray : 当前帧灰度图 + scoremap : 当前帧分割结果 + is_init : 是否第一帧 Returns: - dst : Image after set thresthold mask, ans instance of np.float32 array. + dst : 光流追踪图和预测结果融合图, 类型为 np.float32 """ width, height = scoremap.shape[0], scoremap.shape[1] disflow = cv2.DISOpticalFlow_create( @@ -145,12 +145,8 @@ def optflow_handle(cur_gray, scoremap, is_init): class HumanSeg: - """Human Segmentation Class - This Class instance will load the inference model and do inference - on input image object. - - It includes the key stages for a object segmentation inference task. - Call run_predict on your image and it will return a processed image. + """人像分割类 + 封装了人像分割模型的加载,数据预处理,预测,后处理等 """ def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False): @@ -160,10 +156,10 @@ class HumanSeg: self.load_model(model_dir, use_gpu) def load_model(self, model_dir, use_gpu): - """Load paddle inference model. + """加载模型并创建predictor Args: - model_dir: The inference model path includes `__model__` and `__params__`. - use_gpu: Enable gpu if use_gpu is True + model_dir: 预测模型路径, 包含 `__model__` 和 `__params__` + use_gpu: 是否使用GPU加速 """ prog_file = os.path.join(model_dir, '__model__') params_file = os.path.join(model_dir, '__params__') @@ -179,12 +175,12 @@ class HumanSeg: self.predictor = fluid.core.create_paddle_predictor(config) def preprocess(self, image): - """Preprocess input image. - Convert hwc_rgb to chw_bgr. - Args: - image: The input opencv image object. - Returns: - A preprocessed image object. + """图像预处理 + hwc_rgb 转换为 chw_bgr,并进行归一化 + 输入参数: + image: 原始图像 + 返回值: + 经过预处理后的图片结果 """ img_mat = cv2.resize( image, self.eval_size, interpolation=cv2.INTER_LINEAR) @@ -200,18 +196,18 @@ class HumanSeg: return img_mat def postprocess(self, image, output_data): - """Postprocess the inference result and original input image. + """对预测结果进行后处理 Args: - image: The original opencv image object. - output_data: The inference output of paddle's humansegmentation model. + image: 原始图,opencv 图片对象 + output_data: Paddle预测结果原始数据 Returns: - The result merged original image and segmentation result with optical-flow improvement. + 原图和预测结果融合并做了光流优化的结果图 """ scoremap = output_data[0, 1, :, :] scoremap = (scoremap * 255).astype(np.uint8) ori_h, ori_w = image.shape[0], image.shape[1] evl_h, evl_w = self.eval_size[0], self.eval_size[1] - # optical flow processing + # 光流处理 cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) cur_gray = cv2.resize(cur_gray, (evl_w, evl_h)) optflow_map = optflow_handle(cur_gray, scoremap, False) @@ -224,12 +220,11 @@ class HumanSeg: return comb def run_predict(self, image): - """Run Predicting on an opencv image object. - Preprocess the image, do inference, and then postprocess the infering output. - Args: - image: A valid opencv image object. - Returns: - The segmentation result which represents as an opencv image object. + """运行预测并返回可视化结果图 + 输入参数: + image: 需要预测的原始图, opencv图片对象 + 返回值: + 可视化的预测结果图 """ im_mat = self.preprocess(image) im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32')) @@ -239,13 +234,8 @@ class HumanSeg: def predict_image(seg, image_path): - """Do Predicting on a image file. - Decoding the image file and do predicting on it. - The result will be saved as `result.jpeg`. - Args: - seg: The HumanSeg Object which holds a inference model. - Do preprocessing / predicting / postprocessing on a input image object. - image_path: Path of the image file needs to be processed. + """对图片文件进行分割 + 结果保存到`result.jpeg`文件中 """ img_mat = cv2.imread(image_path) img_mat = seg.run_predict(img_mat) @@ -253,13 +243,8 @@ def predict_image(seg, image_path): def predict_video(seg, video_path): - """Do Predicting on a video file. - Decoding the video file and do predicting on each frame. - All result will be saved as `result.avi`. - Args: - seg: The HumanSeg Object which holds a inference model. - Do preprocessing / predicting / postprocessing on a input image object. - video_path: Path of a video file needs to be processed. + """对视频文件进行分割 + 结果保存到`result.avi`文件中 """ cap = cv2.VideoCapture(video_path) if not cap.isOpened(): @@ -268,11 +253,11 @@ def predict_video(seg, video_path): width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv2.CAP_PROP_FPS) - # Result Video Writer + # 用于保存预测结果视频 out = cv2.VideoWriter('result.avi', cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height)) - # Start capturing from video + # 开始获取视频帧 while cap.isOpened(): ret, frame = cap.read() if ret: @@ -285,12 +270,8 @@ def predict_video(seg, video_path): def predict_camera(seg): - """Do Predicting on a camera video stream. - Capturing each video frame from camera and do predicting on it. - All result frames will be shown in a GUI window. - Args: - seg: The HumanSeg Object which holds a inference model. - Do preprocessing / predicting / postprocessing on a input image object. + """从摄像头获取视频流进行预测 + 视频分割结果实时显示到可视化窗口中 """ cap = cv2.VideoCapture(0) if not cap.isOpened(): @@ -310,36 +291,30 @@ def predict_camera(seg): def main(args): - """Real Entrypoint of the script. - Load the human segmentation inference model and do predicting on the input resource. - Support three types of input: camera stream / video file / image file. - Args: - args: The command-line args for inference model. - Open camera and do predicting on camera stream while `args.use_camera` is true. - Open the video file and do predicting on it while `args.video_path` is valid. - Open the image file and do predicting on it while `args.img_path` is valid. + """预测程序入口 + 完成模型加载, 对视频、摄像头、图片文件等预测过程 """ model_dir = args.model_dir use_gpu = args.use_gpu - # Init model + # 加载模型 mean = [104.008, 116.669, 122.675] scale = [1.0, 1.0, 1.0] eval_size = (192, 192) seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu) if args.use_camera: - # if enable input video stream from camera + # 开启摄像头 predict_camera(seg) elif args.video_path: - # if video_path valid, do predicting on the video + # 使用视频文件作为输入 predict_video(seg, args.video_path) elif args.img_path: - # if img_path valid, do predicting on the image + # 使用图片文件作为输入 predict_image(seg, args.img_path) def parse_args(): - """Parsing command-line argments + """解析命令行参数 """ parser = argparse.ArgumentParser('Realtime Human Segmentation') parser.add_argument('--model_dir', -- GitLab