diff --git a/contrib/HumanSeg/README.md b/contrib/HumanSeg/README.md
index a7cd03fdc1af69b3bf15d982c1966b573c938d71..a43413409bc59594e253b621b3358252c0fdf1e3 100644
--- a/contrib/HumanSeg/README.md
+++ b/contrib/HumanSeg/README.md
@@ -70,10 +70,30 @@ python video_infer.py --model_dir pretrained_weights/humanseg_lite_inference --v
+Replace the background with a background of your choice; the background can be a single image or a video.
+```bash
+# Real-time background replacement using the computer camera; a background video can also be supplied via '--background_video_path'
+python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --background_image_path data/background.jpg
+
+# Background replacement on a portrait video; a background video can also be supplied via '--background_video_path'
+python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --video_path data/video_test.mp4 --background_image_path data/background.jpg
+
+# Background replacement on a single image
+python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --image_path data/human_image.jpg --background_image_path data/background.jpg
+
+```
+
+The background replacement results look like this:
+
+
+
 **NOTE**: Video segmentation takes a few minutes to process; please be patient.
+The provided models are intended for portrait scenes shot vertically with a phone camera; quality on widescreen footage will be slightly lower.
+
 ## Training
 Use the command below to fine-tune from a pretrained model. Make sure the selected model architecture `model_type` matches the pretrained weights `pretrained_weights`.
 ```bash
diff --git a/contrib/HumanSeg/bg_replace.py b/contrib/HumanSeg/bg_replace.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbc6097b92aa788e6599a9366f4ea88b2df3946e
--- /dev/null
+++ b/contrib/HumanSeg/bg_replace.py
@@ -0,0 +1,286 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
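+"""Replace the background of a human image, video, or live camera stream.
+
+The foreground mask comes from a HumanSeg model; for video input the mask
+is stabilised across frames with DIS optical flow (see
+utils/humanseg_postprocess.py) before alpha-blending each frame over the
+chosen background image or looping background video.
+"""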
+
+import argparse
+import os
+import os.path as osp
+import cv2
+import numpy as np
+
+from utils.humanseg_postprocess import postprocess, threshold_mask
+import models
+import transforms
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='HumanSeg background replacement')
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='Model path for inference',
+        type=str)
+    parser.add_argument(
+        '--image_path',
+        dest='image_path',
+        help='Image including human',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--background_image_path',
+        dest='background_image_path',
+        help='Background image for replacing',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--video_path',
+        dest='video_path',
+        help='Video path for inference',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--background_video_path',
+        dest='background_video_path',
+        help='Background video path for replacing',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the inference results',
+        type=str,
+        default='./output')
+    parser.add_argument(
+        "--image_shape",
+        dest="image_shape",
+        help="The image shape for net inputs.",
+        nargs=2,
+        default=[192, 192],
+        type=int)
+
+    return parser.parse_args()
+
+
+def predict(img, model, test_transforms):
+    model.arrange_transform(transforms=test_transforms, mode='test')
+    img, im_info = test_transforms(img)
+    img = np.expand_dims(img, axis=0)
+    result = model.exe.run(
+        model.test_prog,
+        feed={'image': img},
+        fetch_list=list(model.test_outputs.values()))
+    score_map = result[1]
+    score_map = np.squeeze(score_map, axis=0)
+    score_map = np.transpose(score_map, (1, 2, 0))
+    return score_map, im_info
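+
+
+# The test transforms record how the input was resized and padded in
+# `im_info`; recover() replays those records in reverse to map the
+# predicted mask back onto the original image size.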
+def recover(img, im_info):
+    keys = list(im_info.keys())
+    for k in keys[::-1]:
+        if k == 'shape_before_resize':
+            h, w = im_info[k][0], im_info[k][1]
+            img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
+        elif k == 'shape_before_padding':
+            h, w = im_info[k][0], im_info[k][1]
+            img = img[0:h, 0:w]
+    return img
+
+
+def bg_replace(score_map, img, bg):
+    # Alpha-blend the frame over the background: score_map in [0, 1] acts
+    # as the per-pixel foreground opacity.
+    h, w, _ = img.shape
+    bg = cv2.resize(bg, (w, h))
+    score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
+    comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
+    return comb
+
+
+def infer(args):
+    resize_h = args.image_shape[1]
+    resize_w = args.image_shape[0]
+
+    test_transforms = transforms.Compose(
+        [transforms.Resize((resize_w, resize_h)),
+         transforms.Normalize()])
+    model = models.load_model(args.model_dir)
+
+    if not osp.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    # Image background replacement
+    if args.image_path is not None:
+        if not osp.exists(args.image_path):
+            raise FileNotFoundError(
+                'The --image_path does not exist: {}'.format(args.image_path))
+        if args.background_image_path is None:
+            raise ValueError(
+                'The --background_image_path is not set. Please set it.')
+        if not osp.exists(args.background_image_path):
+            raise FileNotFoundError(
+                'The --background_image_path does not exist: {}'.format(
+                    args.background_image_path))
+        img = cv2.imread(args.image_path)
+        score_map, im_info = predict(img, model, test_transforms)
+        score_map = score_map[:, :, 1]
+        score_map = recover(score_map, im_info)
+        bg = cv2.imread(args.background_image_path)
+        save_name = osp.basename(args.image_path)
+        save_path = osp.join(args.save_dir, save_name)
+        result = bg_replace(score_map, img, bg)
+        cv2.imwrite(save_path, result)
+
+    # Video background replacement: use the background video if one is
+    # provided, otherwise use the background image
+    else:
+        is_video_bg = False
+        if args.background_video_path is not None:
+            if not osp.exists(args.background_video_path):
+                raise FileNotFoundError(
+                    'The --background_video_path does not exist: {}'.format(
+                        args.background_video_path))
+            is_video_bg = True
+        elif args.background_image_path is not None:
+            if not osp.exists(args.background_image_path):
+                raise FileNotFoundError(
+                    'The --background_image_path does not exist: {}'.format(
+                        args.background_image_path))
+        else:
+            raise ValueError(
+                'Please provide a background image or video by setting '
+                '--background_image_path or --background_video_path.')
+
+        disflow = cv2.DISOpticalFlow_create(
+            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
+        prev_gray = np.zeros((resize_h, resize_w), np.uint8)
+        prev_cfd = np.zeros((resize_h, resize_w), np.float32)
+        is_init = True
+        if args.video_path is not None:
+            print('Please wait. It is computing...')
+            if not osp.exists(args.video_path):
+                raise FileNotFoundError(
+                    'The --video_path does not exist: {}'.format(
+                        args.video_path))
+
+            cap_video = cv2.VideoCapture(args.video_path)
+            fps = cap_video.get(cv2.CAP_PROP_FPS)
+            width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            save_name = osp.basename(args.video_path)
+            save_name = save_name.split('.')[0]
+            save_path = osp.join(args.save_dir, save_name + '.avi')
+
+            cap_out = cv2.VideoWriter(
+                save_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
+                (width, height))
+
+            if is_video_bg:
+                cap_bg = cv2.VideoCapture(args.background_video_path)
+                frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
+                current_frame_bg = 1
+            else:
+                img_bg = cv2.imread(args.background_image_path)
+            while cap_video.isOpened():
+                ret, frame = cap_video.read()
+                if ret:
+                    score_map, im_info = predict(frame, model, test_transforms)
+                    cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+                    score_map = 255 * score_map[:, :, 1]
+                    optflow_map = postprocess(cur_gray, score_map, prev_gray,
+                                              prev_cfd, disflow, is_init)
+                    prev_gray = cur_gray.copy()
+                    prev_cfd = optflow_map.copy()
+                    is_init = False
+                    optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+                    optflow_map = threshold_mask(
+                        optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+                    score_map = recover(optflow_map, im_info)
+
+                    # Read background frames in a loop, rewinding when the
+                    # background video is exhausted
+                    if is_video_bg:
+                        ret_bg, frame_bg = cap_bg.read()
+                        if ret_bg:
+                            if current_frame_bg == frames_bg:
+                                current_frame_bg = 1
+                                cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                        else:
+                            break
+                        current_frame_bg += 1
+                        comb = bg_replace(score_map, frame, frame_bg)
+                    else:
+                        comb = bg_replace(score_map, frame, img_bg)
+
+                    cap_out.write(comb)
+                else:
+                    break
+
+            if is_video_bg:
+                cap_bg.release()
+            cap_video.release()
+            cap_out.release()
+
+        # If neither an input image nor a video is given, open the camera
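+        # and show the composited result live; press 'q' to quit.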
+        else:
+            cap_video = cv2.VideoCapture(0)
+            if not cap_video.isOpened():
+                raise IOError('Error opening the camera. Please check that '
+                              'your camera is connected and working.')
+
+            if is_video_bg:
+                cap_bg = cv2.VideoCapture(args.background_video_path)
+                frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
+                current_frame_bg = 1
+            else:
+                img_bg = cv2.imread(args.background_image_path)
+            while cap_video.isOpened():
+                ret, frame = cap_video.read()
+                if ret:
+                    score_map, im_info = predict(frame, model, test_transforms)
+                    cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+                    score_map = 255 * score_map[:, :, 1]
+                    optflow_map = postprocess(cur_gray, score_map, prev_gray,
+                                              prev_cfd, disflow, is_init)
+                    prev_gray = cur_gray.copy()
+                    prev_cfd = optflow_map.copy()
+                    is_init = False
+                    optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+                    optflow_map = threshold_mask(
+                        optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+                    score_map = recover(optflow_map, im_info)
+
+                    # Read background frames in a loop
+                    if is_video_bg:
+                        ret_bg, frame_bg = cap_bg.read()
+                        if ret_bg:
+                            if current_frame_bg == frames_bg:
+                                current_frame_bg = 1
+                                cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                        else:
+                            break
+                        current_frame_bg += 1
+                        comb = bg_replace(score_map, frame, frame_bg)
+                    else:
+                        comb = bg_replace(score_map, frame, img_bg)
+                    cv2.imshow('HumanSegmentation', comb)
+                    if cv2.waitKey(1) & 0xFF == ord('q'):
+                        break
+                else:
+                    break
+            if is_video_bg:
+                cap_bg.release()
+            cap_video.release()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    infer(args)
diff --git a/contrib/HumanSeg/data/background.jpg b/contrib/HumanSeg/data/background.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..792e43c2352bd4380e80c2d7541f2ea53f6bcc38
Binary files /dev/null and b/contrib/HumanSeg/data/background.jpg differ
diff --git a/contrib/HumanSeg/data/human_image.jpg b/contrib/HumanSeg/data/human_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d1cfb43e21c0d2822d5d87e58fe22921fd70728d
Binary files /dev/null and b/contrib/HumanSeg/data/human_image.jpg differ
diff --git a/contrib/HumanSeg/utils/humanseg_postprocess.py b/contrib/HumanSeg/utils/humanseg_postprocess.py
index d4624541a74bfd6ec4e99dda225096f975f75ee6..cd4d18da65cd6ccc68a9d727358da40588260a7d 100644
--- a/contrib/HumanSeg/utils/humanseg_postprocess.py
+++ b/contrib/HumanSeg/utils/humanseg_postprocess.py
@@ -14,13 +14,6 @@
 # limitations under the License.
 
 import numpy as np
-import cv2
-import os
-
-
-def get_round(data):
-    round = 0.5 if data >= 0 else -0.5
-    return (int)(data + round)
 
 
 def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
@@ -41,26 +34,28 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
     is_track = np.zeros_like(pre_gray)
     flow_fw = disflow.calc(pre_gray, cur_gray, None)
     flow_bw = disflow.calc(cur_gray, pre_gray, None)
-    for r in range(h):
-        for c in range(w):
-            fxy_fw = flow_fw[r, c]
-            dx_fw = get_round(fxy_fw[0])
-            cur_x = dx_fw + c
-            dy_fw = get_round(fxy_fw[1])
-            cur_y = dy_fw + r
-            if cur_x < 0 or cur_x >= w or cur_y < 0 or cur_y >= h:
-                continue
-            fxy_bw = flow_bw[cur_y, cur_x]
-            dx_bw = get_round(fxy_bw[0])
-            dy_bw = get_round(fxy_bw[1])
-            if ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
-                    (dx_fw + dx_bw) * (dx_fw + dx_bw)) >= check_thres:
-                continue
-            if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
-                    dx_bw) <= 0:
-                dl_weights[cur_y, cur_x] = 0.05
-            is_track[cur_y, cur_x] = 1
-            track_cfd[cur_y, cur_x] = prev_cfd[r, c]
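+    # Vectorised forward-backward consistency check: move each pixel along
+    # the forward flow, read the backward flow at the landing point, and
+    # treat the pixel as reliably tracked only if the round trip returns
+    # close to where it started (squared error below check_thres).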
+    flow_fw = np.round(flow_fw).astype(int)
+    flow_bw = np.round(flow_bw).astype(int)
+    y_list = np.array(range(h))
+    x_list = np.array(range(w))
+    yv, xv = np.meshgrid(y_list, x_list)
+    yv, xv = yv.T, xv.T
+    cur_x = xv + flow_fw[:, :, 0]
+    cur_y = yv + flow_fw[:, :, 1]
+
+    # Do not track pixels that flow outside the image bounds
+    not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
+    flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
+    not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
+                  np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
+    track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
+
+    is_track[cur_y[~not_track], cur_x[~not_track]] = 1
+
+    not_flow = np.all(
+        np.abs(flow_fw) == 0, axis=-1) * np.all(
+            np.abs(flow_bw) == 0, axis=-1)
+    dl_weights[cur_y[not_flow], cur_x[not_flow]] = 0.05
     return track_cfd, is_track, dl_weights
@@ -75,24 +70,27 @@
         cur_cfd: fused map of the optical-flow tracking result and the human segmentation result
     """
     fusion_cfd = dl_cfd.copy()
-    idxs = np.where(is_track > 0)
-    for i in range(len(idxs[0])):
-        x, y = idxs[0][i], idxs[1][i]
-        dl_score = dl_cfd[x, y]
-        track_score = track_cfd[x, y]
-        fusion_cfd[x, y] = dl_weights[x, y] * dl_score + (
-            1 - dl_weights[x, y]) * track_score
-        if dl_score > 0.9 or dl_score < 0.1:
-            if dl_weights[x, y] < 0.1:
-                fusion_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
-            else:
-                fusion_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
-        else:
-            fusion_cfd[x, y] = dl_weights[x, y] * dl_score + (
-                1 - dl_weights[x, y]) * track_score
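+    # Blend the model confidence with the optical-flow-tracked confidence.
+    # dl_weights starts at 0.3 and is lowered to 0.05 where the flow is
+    # zero, so static regions rely mostly on the tracked estimate.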
+    is_track = is_track.astype(bool)
+    fusion_cfd[is_track] = dl_weights[is_track] * dl_cfd[is_track] + (
+        1 - dl_weights[is_track]) * track_cfd[is_track]
+    # Regions where the model is confident (score close to 0 or 1)
+    index_certain = ((dl_cfd > 0.9) + (dl_cfd < 0.1)) * is_track
+    index_less01 = (dl_weights < 0.1) * index_certain
+    fusion_cfd[index_less01] = 0.3 * dl_cfd[index_less01] + 0.7 * track_cfd[
+        index_less01]
+    index_larger01 = (dl_weights >= 0.1) * index_certain
+    fusion_cfd[index_larger01] = 0.4 * dl_cfd[index_larger01] + 0.6 * track_cfd[
+        index_larger01]
     return fusion_cfd
 
 
+def threshold_mask(img, thresh_bg, thresh_fg):
+    # Linearly remap img / 255 from [thresh_bg, thresh_fg] to [0, 1]:
+    # values below thresh_bg clip to 0 (background), values above thresh_fg
+    # clip to 1 (foreground).
+    dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
+    dst[np.where(dst > 1)] = 1
+    dst[np.where(dst < 0)] = 0
+    return dst.astype(np.float32)
+
+
 def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
     """Optical-flow refinement
     Args:
@@ -105,13 +103,10 @@
     Returns:
         fusion_cfd: fused map of the optical-flow tracking result and the model prediction
     """
-    height, width = scoremap.shape[0], scoremap.shape[1]
-    disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
     h, w = scoremap.shape
     cur_cfd = scoremap.copy()
 
     if is_init:
-        is_init = False
         if h <= 64 or w <= 64:
             disflow.setFinestScale(1)
         elif h <= 160 or w <= 160:
@@ -120,18 +115,9 @@
             disflow.setFinestScale(2)
         else:
             disflow.setFinestScale(3)
         fusion_cfd = cur_cfd
     else:
-        weights = np.ones((w, h), np.float32) * 0.3
+        weights = np.ones((h, w), np.float32) * 0.3
         track_cfd, is_track, weights = human_seg_tracking(
             prev_gray, cur_gray, pre_cfd, weights, disflow)
         fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)
-    fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
-
     return fusion_cfd
-
-
-def threshold_mask(img, thresh_bg, thresh_fg):
-    dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
-    dst[np.where(dst > 1)] = 1
-    dst[np.where(dst < 0)] = 0
-    return dst.astype(np.float32)
diff --git a/contrib/HumanSeg/video_infer.py b/contrib/HumanSeg/video_infer.py
index f37809366d06ee4553bb10d848b3e387a18797bd..b170a0c205a0f89be68f671cf3c90c97600295c3 100644
--- a/contrib/HumanSeg/video_infer.py
+++ b/contrib/HumanSeg/video_infer.py
@@ -109,7 +109,7 @@ def video_infer(args):
 
     fps = cap.get(cv2.CAP_PROP_FPS)
     if args.video_path:
-
+        print('Please wait. It is computing...')
         # For saving the predicted video
         if not osp.exists(args.save_dir):
             os.makedirs(args.save_dir)
@@ -123,8 +123,8 @@ def video_infer(args):
             score_map, im_info = predict(frame, model, test_transforms)
             cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
             cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
-            scoremap = 255 * score_map[:, :, 1]
-            optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
+            score_map = 255 * score_map[:, :, 1]
+            optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
                                       disflow, is_init)
             prev_gray = cur_gray.copy()
             prev_cfd = optflow_map.copy()
@@ -132,10 +132,11 @@ def video_infer(args):
             optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
             optflow_map = threshold_mask(
                 optflow_map, thresh_bg=0.2, thresh_fg=0.8)
-            img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
-            img_mat = recover(img_mat, im_info)
-            bg_im = np.ones_like(img_mat) * 255
-            comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
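+            # Broadcast the single-channel mask to 3 channels and matte the
+            # frame over a plain white background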
+            img_matting = np.repeat(
+                optflow_map[:, :, np.newaxis], 3, axis=2)
+            img_matting = recover(img_matting, im_info)
+            bg_im = np.ones_like(img_matting) * 255
+            comb = (img_matting * frame + (1 - img_matting) * bg_im).astype(
                 np.uint8)
             out.write(comb)
         else:
@@ -150,20 +151,20 @@ def video_infer(args):
             score_map, im_info = predict(frame, model, test_transforms)
             cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
             cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
-            scoremap = 255 * score_map[:, :, 1]
-            optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
+            score_map = 255 * score_map[:, :, 1]
+            optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
                                       disflow, is_init)
             prev_gray = cur_gray.copy()
             prev_cfd = optflow_map.copy()
             is_init = False
-            # optflow_map = optflow_map/255.0
             optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
             optflow_map = threshold_mask(
                 optflow_map, thresh_bg=0.2, thresh_fg=0.8)
-            img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
-            img_mat = recover(img_mat, im_info)
-            bg_im = np.ones_like(img_mat) * 255
-            comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
+            img_matting = np.repeat(
+                optflow_map[:, :, np.newaxis], 3, axis=2)
+            img_matting = recover(img_matting, im_info)
+            bg_im = np.ones_like(img_matting) * 255
+            comb = (img_matting * frame + (1 - img_matting) * bg_im).astype(
                 np.uint8)
             cv2.imshow('HumanSegmentation', comb)
             if cv2.waitKey(1) & 0xFF == ord('q'):