diff --git a/contrib/RealTimeHumanSeg/CMakeLists.txt b/contrib/RealTimeHumanSeg/cpp/CMakeLists.txt
similarity index 95%
rename from contrib/RealTimeHumanSeg/CMakeLists.txt
rename to contrib/RealTimeHumanSeg/cpp/CMakeLists.txt
index 992b6d1d9d8ae2c801b27bfda33ab2d89bf0d147..5a7b89acc41da5576a0f0ead7205385feabf5dab 100644
--- a/contrib/RealTimeHumanSeg/CMakeLists.txt
+++ b/contrib/RealTimeHumanSeg/cpp/CMakeLists.txt
@@ -70,10 +70,10 @@ if (WIN32)
     include_directories("${OPENCV_DIR}/opencv/build/include")
     link_directories("${OPENCV_DIR}/build/x64/vc14/lib")
 else ()
+    find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH)
     include_directories("${PADDLE_DIR}/paddle/include")
     link_directories("${PADDLE_DIR}/paddle/lib")
-    include_directories("${OPENCV_DIR}/include")
-    link_directories("${OPENCV_DIR}/lib")
+    include_directories(${OpenCV_INCLUDE_DIRS})
 endif ()
 
 if (WIN32)
@@ -202,12 +202,8 @@ if(WITH_GPU)
 endif()
 
 if (NOT WIN32)
-    set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread"
-        "-lopencv_world -lopencv_img_hash"
-        "-lIlmImf -llibpng -lippiw -lippicv"
-        "-llibtiff -llibwebp -littnotify -llibjasper"
-        "-llibjpeg -lzlib")
-    set(DEPS ${DEPS} ${EXTERNAL_LIB})
+    set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread")
+    set(DEPS ${DEPS} ${EXTERNAL_LIB} ${OpenCV_LIBS})
 endif()
 
 add_executable(main main.cc humanseg.cc humanseg_postprocess.cc)
diff --git a/contrib/RealTimeHumanSeg/CMakeSettings.json b/contrib/RealTimeHumanSeg/cpp/CMakeSettings.json
similarity index 100%
rename from contrib/RealTimeHumanSeg/CMakeSettings.json
rename to contrib/RealTimeHumanSeg/cpp/CMakeSettings.json
diff --git a/contrib/RealTimeHumanSeg/README.md b/contrib/RealTimeHumanSeg/cpp/README.md
similarity index 100%
rename from contrib/RealTimeHumanSeg/README.md
rename to contrib/RealTimeHumanSeg/cpp/README.md
diff --git a/contrib/RealTimeHumanSeg/docs/linux_build.md b/contrib/RealTimeHumanSeg/cpp/docs/linux_build.md
similarity index 94%
rename from contrib/RealTimeHumanSeg/docs/linux_build.md
rename to contrib/RealTimeHumanSeg/cpp/docs/linux_build.md
index 2f7e5fdb0e42cbe691876116ba7ed91923fdf636..823ff3ae7cc6b16d9f5696924ae5def746bc8892 100644
--- a/contrib/RealTimeHumanSeg/docs/linux_build.md
+++ b/contrib/RealTimeHumanSeg/cpp/docs/linux_build.md
@@ -80,3 +80,7 @@ sh linux_build.sh
 ```shell
 ./build/main ./models /PATH/TO/TEST_VIDEO
 ```
+
+Click to download the [test video](https://paddleseg.bj.bcebos.com/deploy/data/test.avi).
+
+The prediction result is saved to the video file `result.avi`.
diff --git a/contrib/RealTimeHumanSeg/docs/windows_build.md b/contrib/RealTimeHumanSeg/cpp/docs/windows_build.md
similarity index 97%
rename from contrib/RealTimeHumanSeg/docs/windows_build.md
rename to contrib/RealTimeHumanSeg/cpp/docs/windows_build.md
index 89b28a3fdc8e95b9c78fbb2a358a74c6e6f261e0..6937dbcff4f55c5a085aa9d0bd2674c04f3ac8e5 100644
--- a/contrib/RealTimeHumanSeg/docs/windows_build.md
+++ b/contrib/RealTimeHumanSeg/cpp/docs/windows_build.md
@@ -78,4 +78,6 @@
 main.exe ./models/ ./data/test.avi
 ```
 The first argument is the path of the human segmentation inference model; the second is the video to run prediction on.
+Click to download the [test video](https://paddleseg.bj.bcebos.com/deploy/data/test.avi).
+
 After running, the prediction result is saved to the file `result.avi`.
diff --git a/contrib/RealTimeHumanSeg/humanseg.cc b/contrib/RealTimeHumanSeg/cpp/humanseg.cc
similarity index 94%
rename from contrib/RealTimeHumanSeg/humanseg.cc
rename to contrib/RealTimeHumanSeg/cpp/humanseg.cc
index 988742bee10f67f696283d5a63d5223ccc14449d..b81c81200064f6191e18cdb39fc8d6414aa5fe9d 100644
--- a/contrib/RealTimeHumanSeg/humanseg.cc
+++ b/contrib/RealTimeHumanSeg/cpp/humanseg.cc
@@ -44,7 +44,9 @@ void LoadModel(
     std::unique_ptr<paddle::PaddlePredictor>* predictor) {
   // Config the model info
   paddle::AnalysisConfig config;
-  config.SetModel(model_dir);
+  auto prog_file = model_dir + "/__model__";
+  auto params_file = model_dir + "/__params__";
+  config.SetModel(prog_file, params_file);
   if (use_gpu) {
     config.EnableUseGpu(100, 0);
   } else {
@@ -60,7 +62,8 @@
 void HumanSeg::Preprocess(const cv::Mat& image_mat) {
   // Clone the image: keep the original mat for postprocess
   cv::Mat im = image_mat.clone();
-  cv::resize(im, im, cv::Size(192, 192), 0.f, 0.f, cv::INTER_LINEAR);
+  auto eval_wh = cv::Size(eval_size_[0], eval_size_[1]);
+  cv::resize(im, im, eval_wh, 0.f, 0.f, cv::INTER_LINEAR);
   im.convertTo(im, CV_32FC3, 1.0);
 
   int rc = im.channels();
diff --git a/contrib/RealTimeHumanSeg/humanseg.h b/contrib/RealTimeHumanSeg/cpp/humanseg.h
similarity index 92%
rename from contrib/RealTimeHumanSeg/humanseg.h
rename to contrib/RealTimeHumanSeg/cpp/humanseg.h
index 9a4223460a71ceed70e25acf31636cd9b89f7dce..edaf825f713847a3b2c8bf5bae3a36de6ec03395 100644
--- a/contrib/RealTimeHumanSeg/humanseg.h
+++ b/contrib/RealTimeHumanSeg/cpp/humanseg.h
@@ -37,9 +37,11 @@ class HumanSeg {
   explicit HumanSeg(const std::string& model_dir,
                     const std::vector<float>& mean,
                     const std::vector<float>& scale,
+                    const std::vector<int>& eval_size,
                     bool use_gpu = false) :
       mean_(mean),
-      scale_(scale) {
+      scale_(scale),
+      eval_size_(eval_size) {
     LoadModel(model_dir, use_gpu, &predictor_);
   }
 
@@ -60,4 +62,5 @@ class HumanSeg {
   std::vector<float> segout_data_;
   std::vector<float> mean_;
   std::vector<float> scale_;
+  std::vector<int> eval_size_;
 };
diff --git a/contrib/RealTimeHumanSeg/humanseg_postprocess.cc b/contrib/RealTimeHumanSeg/cpp/humanseg_postprocess.cc
similarity index 100%
rename from contrib/RealTimeHumanSeg/humanseg_postprocess.cc
rename to contrib/RealTimeHumanSeg/cpp/humanseg_postprocess.cc
diff --git a/contrib/RealTimeHumanSeg/humanseg_postprocess.h b/contrib/RealTimeHumanSeg/cpp/humanseg_postprocess.h
similarity index 100%
rename from contrib/RealTimeHumanSeg/humanseg_postprocess.h
rename to contrib/RealTimeHumanSeg/cpp/humanseg_postprocess.h
diff --git a/contrib/RealTimeHumanSeg/linux_build.sh b/contrib/RealTimeHumanSeg/cpp/linux_build.sh
similarity index 68%
rename from contrib/RealTimeHumanSeg/linux_build.sh
rename to contrib/RealTimeHumanSeg/cpp/linux_build.sh
index a0382fb9330fcd74bd719640313faba51e0019ca..ff0b11bcf60f1b4ec4d7a9f63f7490ffb70ad6e0 100644
--- a/contrib/RealTimeHumanSeg/linux_build.sh
+++ b/contrib/RealTimeHumanSeg/cpp/linux_build.sh
@@ -1,10 +1,10 @@
-OPENCV_URL=https://paddleseg.bj.bcebos.com/deploy/deps/opencv341.tar.bz2
-if [ ! -d "./deps/opencv341" ]; then
+OPENCV_URL=https://paddleseg.bj.bcebos.com/deploy/deps/opencv346.tar.bz2
+if [ ! -d "./deps/opencv346" ]; then
     mkdir -p deps
     cd deps
     wget -c ${OPENCV_URL}
-    tar xvfj opencv341.tar.bz2
-    rm -rf opencv341.tar.bz2
+    tar xvfj opencv346.tar.bz2
+    rm -rf opencv346.tar.bz2
     cd ..
 fi
 
@@ -12,7 +12,8 @@ WITH_GPU=OFF
 PADDLE_DIR=/root/projects/deps/fluid_inference/
 CUDA_LIB=/usr/local/cuda-10.0/lib64/
 CUDNN_LIB=/usr/local/cuda-10.0/lib64/
-OPENCV_DIR=$(pwd)/deps/opencv341/
+OPENCV_DIR=$(pwd)/deps/opencv346/
+echo ${OPENCV_DIR}
 
 rm -rf build
 mkdir -p build
diff --git a/contrib/RealTimeHumanSeg/main.cc b/contrib/RealTimeHumanSeg/cpp/main.cc
similarity index 96%
rename from contrib/RealTimeHumanSeg/main.cc
rename to contrib/RealTimeHumanSeg/cpp/main.cc
index fea3a548dd1ef556066d5985c114226f6da6dca5..303051f051b885a83b0ef608fe2ab1319f97294e 100644
--- a/contrib/RealTimeHumanSeg/main.cc
+++ b/contrib/RealTimeHumanSeg/cpp/main.cc
@@ -78,7 +78,8 @@ int main(int argc, char* argv[]) {
   // Init Model
   std::vector<float> means = {104.008, 116.669, 122.675};
   std::vector<float> scale = {1.000, 1.000, 1.000};
-  HumanSeg seg(model_dir, means, scale, use_gpu);
+  std::vector<int> eval_sz = {192, 192};
+  HumanSeg seg(model_dir, means, scale, eval_sz, use_gpu);
 
   // Call ImagePredict while input_path is an image file path
   // The output will be saved as result.jpeg
diff --git a/contrib/RealTimeHumanSeg/python/infer.py b/contrib/RealTimeHumanSeg/python/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc818249b91ad8c41616022370f2df3db989bdcb
--- /dev/null
+++ b/contrib/RealTimeHumanSeg/python/infer.py
@@ -0,0 +1,370 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python inference solution for real-time human segmentation."""
+
+import os
+import argparse
+
+import numpy as np
+import cv2
+
+import paddle.fluid as fluid
+
+
+def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
+    """Optical flow tracking for human segmentation.
+    Args:
+        pre_gray: Grayscale of the previous frame.
+        cur_gray: Grayscale of the current frame.
+        prev_cfd: Confidence map of the previous frame.
+        dl_weights: Merged weights data.
+        disflow: A data structure representing optical flow.
+    Returns:
+        track_cfd: Confidence map tracked by optical flow.
+        is_track: Binary map; whether a pixel is matched by an optical flow point.
+        dl_weights: Updated merged weights data.
+    """
+    check_thres = 8
+    hgt, wdh = pre_gray.shape[:2]
+    track_cfd = np.zeros_like(prev_cfd)
+    is_track = np.zeros_like(pre_gray)
+    # Compute forward optical flow
+    flow_fw = disflow.calc(pre_gray, cur_gray, None)
+    # Compute backward optical flow
+    flow_bw = disflow.calc(cur_gray, pre_gray, None)
+    get_round = lambda data: int(data + 0.5) if data >= 0 else int(data - 0.5)
+    for row in range(hgt):
+        for col in range(wdh):
+            # Calculate the new coordinate after the optical flow step:
+            # (row, col) -> (cur_y, cur_x)
+            fxy_fw = flow_fw[row, col]
+            dx_fw = get_round(fxy_fw[0])
+            cur_x = dx_fw + col
+            dy_fw = get_round(fxy_fw[1])
+            cur_y = dy_fw + row
+            if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
+                continue
+            fxy_bw = flow_bw[cur_y, cur_x]
+            dx_bw = get_round(fxy_bw[0])
+            dy_bw = get_round(fxy_bw[1])
+            # Filter flow points with a forward-backward consistency threshold
+            lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
+                   (dx_fw + dx_bw) * (dx_fw + dx_bw))
+            if lmt >= check_thres:
+                continue
+            # Downgrade still points
+            if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0:
+                dl_weights[cur_y, cur_x] = 0.05
+            is_track[cur_y, cur_x] = 1
+            track_cfd[cur_y, cur_x] = prev_cfd[row, col]
+    return track_cfd, is_track, dl_weights
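+
+
+# A quick sanity check of the forward-backward consistency test above, with
+# illustrative numbers (not taken from real data): a pixel pushed forward by
+# (dx_fw, dy_fw) = (3, -2) whose backward flow at the landing point is
+# (dx_bw, dy_bw) = (-3, 2) gives lmt = 0 < check_thres = 8, so its previous
+# confidence is carried over; a round trip that misses by 3 pixels on one
+# axis (3 * 3 = 9 >= 8) is rejected as an unreliable match.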
+
+
+def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
+    """Fuse the optical flow track with the segmentation result.
+    Args:
+        track_cfd: Confidence map tracked by optical flow.
+        dl_cfd: Segmentation result of the current frame.
+        dl_weights: Merged weights data.
+        is_track: Binary map; whether a pixel is matched by an optical flow point.
+    Returns:
+        cur_cfd: Fusion of the optical flow track and the segmentation result.
+    """
+    cur_cfd = dl_cfd.copy()
+    idxs = np.where(is_track > 0)
+    # idxs is a (rows, cols) tuple; iterate over the matched pixels
+    for i in range(len(idxs[0])):
+        x, y = idxs[0][i], idxs[1][i]
+        dl_score = dl_cfd[x, y]
+        track_score = track_cfd[x, y]
+        if dl_score > 0.9 or dl_score < 0.1:
+            if dl_weights[x, y] < 0.1:
+                cur_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
+            else:
+                cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
+        else:
+            cur_cfd[x, y] = dl_weights[x, y] * dl_score + (1 - dl_weights[x, y]) * track_score
+    return cur_cfd
+
+
+def threshold_mask(img, thresh_bg, thresh_fg):
+    """Threshold mask for image foreground and background.
+    Args:
+        img: Original image, an instance of np.uint8 array.
+        thresh_bg: Threshold for background; values below it map to 0.
+        thresh_fg: Threshold for foreground; values above it map to 1.
+    Returns:
+        dst: Image after thresholding, an instance of np.float32 array.
+    """
+    dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
+    dst[np.where(dst > 1)] = 1
+    dst[np.where(dst < 0)] = 0
+    return dst.astype(np.float32)
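+
+
+# Worked example for threshold_mask (illustrative numbers): with
+# thresh_bg=0.2 and thresh_fg=0.8 -- the values used in postprocess below --
+# a pixel of 127 maps to (127 / 255.0 - 0.2) / 0.6 ~ 0.50, anything at or
+# below 51 (0.2 * 255) clamps to 0.0, and anything at or above 204
+# (0.8 * 255) clamps to 1.0, leaving a linear ramp in between.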
+
+
+def optflow_handle(cur_gray, scoremap, is_init):
+    """Process the optical flow and the segmentation result.
+    Args:
+        cur_gray: Grayscale of the current frame.
+        scoremap: Segmentation result of the current frame.
+        is_init: True only when processing the first frame of a video.
+    Returns:
+        fusion_cfd: Confidence map fusing the optical flow track and the
+            segmentation result.
+    """
+    height, width = scoremap.shape[0], scoremap.shape[1]
+    disflow = cv2.DISOpticalFlow_create(
+        cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
+    prev_gray = np.zeros((height, width), np.uint8)
+    prev_cfd = np.zeros((height, width), np.float32)
+    cur_cfd = scoremap.copy()
+    if is_init:
+        is_init = False
+        if height <= 64 or width <= 64:
+            disflow.setFinestScale(1)
+        elif height <= 160 or width <= 160:
+            disflow.setFinestScale(2)
+        else:
+            disflow.setFinestScale(3)
+        fusion_cfd = cur_cfd
+    else:
+        weights = np.ones((height, width), np.float32) * 0.3
+        track_cfd, is_track, weights = human_seg_tracking(
+            prev_gray, cur_gray, prev_cfd, weights, disflow)
+        fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)
+    fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
+    return fusion_cfd
+
+
+class HumanSeg:
+    """Human segmentation class.
+    An instance loads the inference model and runs inference
+    on input image objects.
+
+    It includes the key stages of an object segmentation inference task.
+    Call run_predict on your image and it will return a processed image.
+    """
+    def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
+        self.mean = np.array(mean).reshape((3, 1, 1))
+        self.scale = np.array(scale).reshape((3, 1, 1))
+        self.eval_size = eval_size
+        self.load_model(model_dir, use_gpu)
+
+    def load_model(self, model_dir, use_gpu):
+        """Load the paddle inference model.
+        Args:
+            model_dir: The inference model path, which includes `__model__` and `__params__`.
+            use_gpu: Enable GPU if use_gpu is True.
+        """
+        prog_file = os.path.join(model_dir, '__model__')
+        params_file = os.path.join(model_dir, '__params__')
+        config = fluid.core.AnalysisConfig(prog_file, params_file)
+        if use_gpu:
+            config.enable_use_gpu(100, 0)
+            config.switch_ir_optim(True)
+        else:
+            config.disable_gpu()
+        config.disable_glog_info()
+        config.switch_specify_input_names(True)
+        config.enable_memory_optim()
+        self.predictor = fluid.core.create_paddle_predictor(config)
+
+    def preprocess(self, image):
+        """Preprocess the input image.
+        Resize to the eval size and convert the HWC layout to CHW
+        (OpenCV already supplies BGR channels).
+        Args:
+            image: The input opencv image object.
+        Returns:
+            A preprocessed image array.
+        """
+        img_mat = cv2.resize(
+            image, self.eval_size, interpolation=cv2.INTER_LINEAR)
+        # HWC -> CHW
+        img_mat = img_mat.swapaxes(1, 2)
+        img_mat = img_mat.swapaxes(0, 1)
+        # Convert to float and normalize: (img_mat - mean) * scale
+        img_mat = img_mat.astype('float32')
+        img_mat = img_mat - self.mean
+        img_mat = img_mat * self.scale
+        img_mat = img_mat[np.newaxis, :, :, :]
+        return img_mat
+
+    def postprocess(self, image, output_data):
+        """Postprocess the inference result and the original input image.
+        Args:
+            image: The original opencv image object.
+            output_data: The inference output of paddle's human segmentation model.
+        Returns:
+            The original image merged with the segmentation result, refined
+            by optical flow.
+        """
+        scoremap = output_data[0, 1, :, :]
+        scoremap = (scoremap * 255).astype(np.uint8)
+        ori_h, ori_w = image.shape[0], image.shape[1]
+        evl_h, evl_w = self.eval_size[0], self.eval_size[1]
+        # Optical flow processing
+        cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        cur_gray = cv2.resize(cur_gray, (evl_w, evl_h))
+        optflow_map = optflow_handle(cur_gray, scoremap, False)
+        optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+        optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+        optflow_map = cv2.resize(optflow_map, (ori_w, ori_h))
+        optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
+        bg_im = np.ones_like(optflow_map) * 255
+        comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(np.uint8)
+        return comb
+
+    def run_predict(self, image):
+        """Run prediction on an opencv image object.
+        Preprocess the image, run inference, and postprocess the inference output.
+        Args:
+            image: A valid opencv image object.
+        Returns:
+            The segmentation result, represented as an opencv image object.
+        """
+        im_mat = self.preprocess(image)
+        im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
+        output_data = self.predictor.run([im_tensor])[0]
+        output_data = output_data.as_ndarray()
+        return self.postprocess(image, output_data)
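+
+
+# Minimal usage sketch (paths are hypothetical; mean/scale/eval_size values
+# match those used in main() below):
+#     seg = HumanSeg('./models', [104.008, 116.669, 122.675],
+#                    [1.0, 1.0, 1.0], (192, 192), use_gpu=False)
+#     result = seg.run_predict(cv2.imread('input.jpg'))
+#     cv2.imwrite('result.jpeg', result)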
+
+
+def predict_image(seg, image_path):
+    """Run prediction on an image file.
+    Decode the image file and run prediction on it.
+    The result will be saved as `result.jpeg`.
+    Args:
+        seg: The HumanSeg object, which holds an inference model and does
+            preprocessing / prediction / postprocessing on an input image object.
+        image_path: Path of the image file to be processed.
+    """
+    img_mat = cv2.imread(image_path)
+    img_mat = seg.run_predict(img_mat)
+    cv2.imwrite('result.jpeg', img_mat)
+
+
+def predict_video(seg, video_path):
+    """Run prediction on a video file.
+    Decode the video file and run prediction on each frame.
+    All results will be saved as `result.avi`.
+    Args:
+        seg: The HumanSeg object, which holds an inference model and does
+            preprocessing / prediction / postprocessing on an input image object.
+        video_path: Path of the video file to be processed.
+    """
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        print("Error opening video stream or file")
+        return
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    # Result video writer
+    out = cv2.VideoWriter('result.avi',
+                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
+                          (width, height))
+    # Start capturing from the video
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if ret:
+            img_mat = seg.run_predict(frame)
+            out.write(img_mat)
+        else:
+            break
+    cap.release()
+    out.release()
+
+
+def predict_camera(seg):
+    """Run prediction on a camera video stream.
+    Capture each video frame from the camera and run prediction on it.
+    All result frames are shown in a GUI window.
+    Args:
+        seg: The HumanSeg object, which holds an inference model and does
+            preprocessing / prediction / postprocessing on an input image object.
+    """
+    cap = cv2.VideoCapture(0)
+    if not cap.isOpened():
+        print("Error opening video stream or file")
+        return
+    # Start capturing from the camera; press 'q' to quit
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if ret:
+            img_mat = seg.run_predict(frame)
+            cv2.imshow('HumanSegmentation', img_mat)
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
+        else:
+            break
+    cap.release()
+
+
+def main(args):
+    """Entrypoint of the script.
+    Load the human segmentation inference model and run prediction on the
+    input resource. Three input types are supported:
+    camera stream / video file / image file.
+    Args:
+        args: The command-line args for the inference model.
+            Opens the camera and predicts on its stream when `args.use_camera`
+            is true; otherwise predicts on the video file when `args.video_path`
+            is valid, or on the image file when `args.img_path` is valid.
+ """ + model_dir = args.model_dir + use_gpu = args.use_gpu + + # Init model + mean = [104.008, 116.669, 122.675] + scale = [1.0, 1.0, 1.0] + eval_size = (192, 192) + seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu) + if args.use_camera: + # if enable input video stream from camera + predict_camera(seg) + elif args.video_path: + # if video_path valid, do predicting on the video + predict_video(seg, args.video_path) + elif args.img_path: + # if img_path valid, do predicting on the image + predict_image(seg, args.img_path) + + +def parse_args(): + """Parsing command-line argments + """ + parser = argparse.ArgumentParser('Realtime Human Segmentation') + parser.add_argument('--model_dir', + type=str, + default='', + help='path of human segmentation model') + parser.add_argument('--img_path', + type=str, + default='', + help='path of input image') + parser.add_argument('--video_path', + type=str, + default='', + help='path of input video') + parser.add_argument('--use_camera', + type=bool, + default=False, + help='input video stream from camera') + parser.add_argument('--use_gpu', + type=bool, + default=False, + help='enable gpu') + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/contrib/RealTimeHumanSeg/python/requirements.txt b/contrib/RealTimeHumanSeg/python/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..953dae0cf5e2036ad093907b30ac9a3a10858d27 --- /dev/null +++ b/contrib/RealTimeHumanSeg/python/requirements.txt @@ -0,0 +1,2 @@ +opencv-python==4.1.2.30 +opencv-contrib-python==4.2.0.32