diff --git a/contrib/HumanSeg/README.md b/contrib/HumanSeg/README.md
index a7cd03fdc1af69b3bf15d982c1966b573c938d71..a43413409bc59594e253b621b3358252c0fdf1e3 100644
--- a/contrib/HumanSeg/README.md
+++ b/contrib/HumanSeg/README.md
@@ -70,10 +70,30 @@ python video_infer.py --model_dir pretrained_weights/humanseg_lite_inference --v
+Replace the background with a chosen one. The background can be a single image or a video.
+```bash
+# Real-time background replacement using the computer camera; a background video can also be passed via '--background_video_path'
+python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --background_image_path data/background.jpg
+
+# Background replacement on a human video; a background video can also be passed via '--background_video_path'
+python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --video_path data/video_test.mp4 --background_image_path data/background.jpg
+
+# Background replacement on a single image
+python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --image_path data/human_image.jpg --background_image_path data/background.jpg
+
+```
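+
+The helpers in `bg_replace.py` can also be called from Python. Below is a minimal sketch of the single-image flow (the paths are examples; it assumes the script runs from `contrib/HumanSeg` so that `models` and `transforms` are importable):
+```python
+import cv2
+
+import models
+import transforms
+from bg_replace import bg_replace, predict, recover
+
+# Build the same test-time transforms and load the exported model
+test_transforms = transforms.Compose(
+    [transforms.Resize((192, 192)), transforms.Normalize()])
+model = models.load_model('pretrained_weights/humanseg_lite_inference')
+
+img = cv2.imread('data/human_image.jpg')
+bg = cv2.imread('data/background.jpg')
+score_map, im_info = predict(img, model, test_transforms)
+# Channel 1 is the foreground probability; map it back to the input size
+score_map = recover(score_map[:, :, 1], im_info)
+cv2.imwrite('output/result.jpg', bg_replace(score_map, img, bg))
+```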
+
+The background replacement results are shown below:
+
+
+
+
**NOTE**:
Video segmentation takes a few minutes to process; please be patient.
+The provided models are intended for portrait scenes shot vertically with a phone camera; results on wide (landscape) frames will be slightly worse.
+
## Training
Use the following command to fine-tune from a pretrained model. Make sure the selected model architecture `model_type` matches the pretrained weights `pretrained_weights`.
```bash
diff --git a/contrib/HumanSeg/bg_replace.py b/contrib/HumanSeg/bg_replace.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbc6097b92aa788e6599a9366f4ea88b2df3946e
--- /dev/null
+++ b/contrib/HumanSeg/bg_replace.py
@@ -0,0 +1,286 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import os.path as osp
+import cv2
+import numpy as np
+
+from utils.humanseg_postprocess import postprocess, threshold_mask
+import models
+import transforms
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='HumanSeg background replacement for images and videos')
+ parser.add_argument(
+ '--model_dir',
+ dest='model_dir',
+ help='Model path for inference',
+ type=str)
+ parser.add_argument(
+ '--image_path',
+ dest='image_path',
+ help='Image including human',
+ type=str,
+ default=None)
+ parser.add_argument(
+ '--background_image_path',
+ dest='background_image_path',
+ help='Background image for replacing',
+ type=str,
+ default=None)
+ parser.add_argument(
+ '--video_path',
+ dest='video_path',
+ help='Video path for inference',
+ type=str,
+ default=None)
+ parser.add_argument(
+ '--background_video_path',
+ dest='background_video_path',
+ help='Background video path for replacing',
+ type=str,
+ default=None)
+ parser.add_argument(
+ '--save_dir',
+ dest='save_dir',
+ help='The directory for saving the inference results',
+ type=str,
+ default='./output')
+ parser.add_argument(
+ "--image_shape",
+ dest="image_shape",
+ help="The image shape for net inputs.",
+ nargs=2,
+ default=[192, 192],
+ type=int)
+
+ return parser.parse_args()
+
+
+def predict(img, model, test_transforms):
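+    """Run one forward pass and return the score map in HWC layout,
+    together with the im_info needed to undo test-time resizing/padding.
+    """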
+ model.arrange_transform(transforms=test_transforms, mode='test')
+ img, im_info = test_transforms(img)
+ img = np.expand_dims(img, axis=0)
+ result = model.exe.run(
+ model.test_prog,
+ feed={'image': img},
+ fetch_list=list(model.test_outputs.values()))
+ score_map = result[1]
+ score_map = np.squeeze(score_map, axis=0)
+ score_map = np.transpose(score_map, (1, 2, 0))
+ return score_map, im_info
+
+
+def recover(img, im_info):
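+    """Undo the test-time transforms recorded in im_info, applied in
+    reverse order, to map the result back to the original image size.
+    """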
+ keys = list(im_info.keys())
+ for k in keys[::-1]:
+ if k == 'shape_before_resize':
+ h, w = im_info[k][0], im_info[k][1]
+ img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
+ elif k == 'shape_before_padding':
+ h, w = im_info[k][0], im_info[k][1]
+ img = img[0:h, 0:w]
+ return img
+
+
+def bg_replace(score_map, img, bg):
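+    """Blend the person over the background: the score map serves as a
+    soft alpha matte and the background is resized to the image size.
+    """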
+ h, w, _ = img.shape
+ bg = cv2.resize(bg, (w, h))
+ score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
+ comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
+ return comb
+
+
+def infer(args):
+ resize_h = args.image_shape[1]
+ resize_w = args.image_shape[0]
+
+ test_transforms = transforms.Compose(
+ [transforms.Resize((resize_w, resize_h)),
+ transforms.Normalize()])
+ model = models.load_model(args.model_dir)
+
+ if not osp.exists(args.save_dir):
+ os.makedirs(args.save_dir)
+
+    # Background replacement for a single image
+ if args.image_path is not None:
+ if not osp.exists(args.image_path):
+            raise FileNotFoundError(
+                'The --image_path does not exist: {}'.format(args.image_path))
+ if args.background_image_path is None:
+            raise ValueError(
+                'The --background_image_path is not set. Please set it')
+ else:
+ if not osp.exists(args.background_image_path):
+                raise FileNotFoundError(
+                    'The --background_image_path does not exist: {}'.format(
+                        args.background_image_path))
+ img = cv2.imread(args.image_path)
+ score_map, im_info = predict(img, model, test_transforms)
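+        # Channel 1 is the foreground (person) probability; map it back
+        # to the original image size before blending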
+ score_map = score_map[:, :, 1]
+ score_map = recover(score_map, im_info)
+ bg = cv2.imread(args.background_image_path)
+ save_name = osp.basename(args.image_path)
+ save_path = osp.join(args.save_dir, save_name)
+ result = bg_replace(score_map, img, bg)
+ cv2.imwrite(save_path, result)
+
+    # Background replacement for video: use the background video if provided, otherwise use the background image
+ else:
+ is_video_bg = False
+ if args.background_video_path is not None:
+ if not osp.exists(args.background_video_path):
+                raise FileNotFoundError(
+                    'The --background_video_path does not exist: {}'.format(
+                        args.background_video_path))
+ is_video_bg = True
+ elif args.background_image_path is not None:
+ if not osp.exists(args.background_image_path):
+                raise FileNotFoundError(
+                    'The --background_image_path does not exist: {}'.format(
+                        args.background_image_path))
+ else:
+            raise ValueError(
+                'Please provide a background image or video. You should set '
+                '--background_image_path or --background_video_path')
+
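+        # DIS optical flow plus the previous frame's gray image and
+        # confidence map are used to temporally smooth per-frame masks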
+ disflow = cv2.DISOpticalFlow_create(
+ cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
+ prev_gray = np.zeros((resize_h, resize_w), np.uint8)
+ prev_cfd = np.zeros((resize_h, resize_w), np.float32)
+ is_init = True
+ if args.video_path is not None:
+            print('Please wait. Processing the video......')
+ if not osp.exists(args.video_path):
+                raise FileNotFoundError(
+                    'The --video_path does not exist: {}'.format(args.video_path))
+
+ cap_video = cv2.VideoCapture(args.video_path)
+ fps = cap_video.get(cv2.CAP_PROP_FPS)
+ width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ save_name = osp.basename(args.video_path)
+ save_name = save_name.split('.')[0]
+ save_path = osp.join(args.save_dir, save_name + '.avi')
+
+ cap_out = cv2.VideoWriter(
+ save_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
+ (width, height))
+
+ if is_video_bg:
+ cap_bg = cv2.VideoCapture(args.background_video_path)
+ frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
+ current_frame_bg = 1
+ else:
+ img_bg = cv2.imread(args.background_image_path)
+ while cap_video.isOpened():
+ ret, frame = cap_video.read()
+ if ret:
+ score_map, im_info = predict(frame, model, test_transforms)
+ cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+ score_map = 255 * score_map[:, :, 1]
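+                    # Refine the raw score map with optical-flow tracking
+                    # to reduce flicker between frames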
+ optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
+ disflow, is_init)
+ prev_gray = cur_gray.copy()
+ prev_cfd = optflow_map.copy()
+ is_init = False
+ optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+ optflow_map = threshold_mask(
+ optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+ score_map = recover(optflow_map, im_info)
+
+                    # Read background frames in a loop, rewinding when the background video ends
+ if is_video_bg:
+ ret_bg, frame_bg = cap_bg.read()
+ if ret_bg:
+ if current_frame_bg == frames_bg:
+ current_frame_bg = 1
+ cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ else:
+ break
+ current_frame_bg += 1
+ comb = bg_replace(score_map, frame, frame_bg)
+ else:
+ comb = bg_replace(score_map, frame, img_bg)
+
+ cap_out.write(comb)
+ else:
+ break
+
+ if is_video_bg:
+ cap_bg.release()
+ cap_video.release()
+ cap_out.release()
+
+        # If neither an image nor a video is provided, open the camera
+ else:
+ cap_video = cv2.VideoCapture(0)
+ if not cap_video.isOpened():
+                raise IOError('Error opening the camera stream. Please check '
+                              'whether the camera is connected and working.')
+
+ if is_video_bg:
+ cap_bg = cv2.VideoCapture(args.background_video_path)
+ frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
+ current_frame_bg = 1
+ else:
+ img_bg = cv2.imread(args.background_image_path)
+ while cap_video.isOpened():
+ ret, frame = cap_video.read()
+ if ret:
+ score_map, im_info = predict(frame, model, test_transforms)
+ cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+ score_map = 255 * score_map[:, :, 1]
+ optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
+ disflow, is_init)
+ prev_gray = cur_gray.copy()
+ prev_cfd = optflow_map.copy()
+ is_init = False
+ optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+ optflow_map = threshold_mask(
+ optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+ score_map = recover(optflow_map, im_info)
+
+                    # Read background frames in a loop, rewinding when the background video ends
+ if is_video_bg:
+ ret_bg, frame_bg = cap_bg.read()
+ if ret_bg:
+ if current_frame_bg == frames_bg:
+ current_frame_bg = 1
+ cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ else:
+ break
+ current_frame_bg += 1
+ comb = bg_replace(score_map, frame, frame_bg)
+ else:
+ comb = bg_replace(score_map, frame, img_bg)
+ cv2.imshow('HumanSegmentation', comb)
+ if cv2.waitKey(1) & 0xFF == ord('q'):
+ break
+ else:
+ break
+ if is_video_bg:
+ cap_bg.release()
+ cap_video.release()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/contrib/HumanSeg/data/background.jpg b/contrib/HumanSeg/data/background.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..792e43c2352bd4380e80c2d7541f2ea53f6bcc38
Binary files /dev/null and b/contrib/HumanSeg/data/background.jpg differ
diff --git a/contrib/HumanSeg/data/human_image.jpg b/contrib/HumanSeg/data/human_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d1cfb43e21c0d2822d5d87e58fe22921fd70728d
Binary files /dev/null and b/contrib/HumanSeg/data/human_image.jpg differ
diff --git a/contrib/HumanSeg/utils/humanseg_postprocess.py b/contrib/HumanSeg/utils/humanseg_postprocess.py
index d4624541a74bfd6ec4e99dda225096f975f75ee6..cd4d18da65cd6ccc68a9d727358da40588260a7d 100644
--- a/contrib/HumanSeg/utils/humanseg_postprocess.py
+++ b/contrib/HumanSeg/utils/humanseg_postprocess.py
@@ -14,13 +14,6 @@
# limitations under the License.
import numpy as np
-import cv2
-import os
-
-
-def get_round(data):
- round = 0.5 if data >= 0 else -0.5
- return (int)(data + round)
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
@@ -41,26 +34,28 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
is_track = np.zeros_like(pre_gray)
flow_fw = disflow.calc(pre_gray, cur_gray, None)
flow_bw = disflow.calc(cur_gray, pre_gray, None)
- for r in range(h):
- for c in range(w):
- fxy_fw = flow_fw[r, c]
- dx_fw = get_round(fxy_fw[0])
- cur_x = dx_fw + c
- dy_fw = get_round(fxy_fw[1])
- cur_y = dy_fw + r
- if cur_x < 0 or cur_x >= w or cur_y < 0 or cur_y >= h:
- continue
- fxy_bw = flow_bw[cur_y, cur_x]
- dx_bw = get_round(fxy_bw[0])
- dy_bw = get_round(fxy_bw[1])
- if ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
- (dx_fw + dx_bw) * (dx_fw + dx_bw)) >= check_thres:
- continue
- if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
- dx_bw) <= 0:
- dl_weights[cur_y, cur_x] = 0.05
- is_track[cur_y, cur_x] = 1
- track_cfd[cur_y, cur_x] = prev_cfd[r, c]
+    flow_fw = np.round(flow_fw).astype(int)
+    flow_bw = np.round(flow_bw).astype(int)
+ y_list = np.array(range(h))
+ x_list = np.array(range(w))
+ yv, xv = np.meshgrid(y_list, x_list)
+ yv, xv = yv.T, xv.T
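+    # Destination coordinates of each pixel after applying the forward flow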
+ cur_x = xv + flow_fw[:, :, 0]
+ cur_y = yv + flow_fw[:, :, 1]
+
+    # Do not track pixels whose forward flow leaves the frame
+ not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
+ flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
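+    # Forward-backward consistency: a large round-trip flow error marks
+    # the pixel as unreliable and excludes it from tracking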
+ not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
+ np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
+ track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
+
+ is_track[cur_y[~not_track], cur_x[~not_track]] = 1
+
+ not_flow = np.all(
+ np.abs(flow_fw) == 0, axis=-1) * np.all(
+ np.abs(flow_bw) == 0, axis=-1)
+ dl_weights[cur_y[not_flow], cur_x[not_flow]] = 0.05
return track_cfd, is_track, dl_weights
@@ -75,24 +70,27 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
        cur_cfd: fusion of the optical-flow tracking map and the human segmentation result
"""
fusion_cfd = dl_cfd.copy()
- idxs = np.where(is_track > 0)
- for i in range(len(idxs[0])):
- x, y = idxs[0][i], idxs[1][i]
- dl_score = dl_cfd[x, y]
- track_score = track_cfd[x, y]
- fusion_cfd[x, y] = dl_weights[x, y] * dl_score + (
- 1 - dl_weights[x, y]) * track_score
- if dl_score > 0.9 or dl_score < 0.1:
- if dl_weights[x, y] < 0.1:
- fusion_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
- else:
- fusion_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
- else:
- fusion_cfd[x, y] = dl_weights[x, y] * dl_score + (
- 1 - dl_weights[x, y]) * track_score
+    is_track = is_track.astype(bool)
+ fusion_cfd[is_track] = dl_weights[is_track] * dl_cfd[is_track] + (
+ 1 - dl_weights[is_track]) * track_cfd[is_track]
+    # Regions where the segmentation output is confident
+ index_certain = ((dl_cfd > 0.9) + (dl_cfd < 0.1)) * is_track
+ index_less01 = (dl_weights < 0.1) * index_certain
+ fusion_cfd[index_less01] = 0.3 * dl_cfd[index_less01] + 0.7 * track_cfd[
+ index_less01]
+ index_larger09 = (dl_weights >= 0.1) * index_certain
+ fusion_cfd[index_larger09] = 0.4 * dl_cfd[index_larger09] + 0.6 * track_cfd[
+ index_larger09]
return fusion_cfd
+def threshold_mask(img, thresh_bg, thresh_fg):
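+    """Rescale a 0-255 confidence map to a [0, 1] alpha matte: values
+    below thresh_bg map to 0, values above thresh_fg map to 1, and values
+    in between are stretched linearly.
+    """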
+ dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
+ dst[np.where(dst > 1)] = 1
+ dst[np.where(dst < 0)] = 0
+ return dst.astype(np.float32)
+
+
def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
"""光流优化
Args:
@@ -105,13 +103,10 @@ def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
Returns:
        fusion_cfd: fusion of the optical-flow tracking map and the prediction result
"""
- height, width = scoremap.shape[0], scoremap.shape[1]
- disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
h, w = scoremap.shape
cur_cfd = scoremap.copy()
if is_init:
- is_init = False
if h <= 64 or w <= 64:
disflow.setFinestScale(1)
elif h <= 160 or w <= 160:
@@ -120,18 +115,9 @@ def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
- weights = np.ones((w, h), np.float32) * 0.3
+ weights = np.ones((h, w), np.float32) * 0.3
track_cfd, is_track, weights = human_seg_tracking(
prev_gray, cur_gray, pre_cfd, weights, disflow)
fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)
- fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
-
return fusion_cfd
-
-
-def threshold_mask(img, thresh_bg, thresh_fg):
- dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
- dst[np.where(dst > 1)] = 1
- dst[np.where(dst < 0)] = 0
- return dst.astype(np.float32)
diff --git a/contrib/HumanSeg/video_infer.py b/contrib/HumanSeg/video_infer.py
index f37809366d06ee4553bb10d848b3e387a18797bd..b170a0c205a0f89be68f671cf3c90c97600295c3 100644
--- a/contrib/HumanSeg/video_infer.py
+++ b/contrib/HumanSeg/video_infer.py
@@ -109,7 +109,7 @@ def video_infer(args):
fps = cap.get(cv2.CAP_PROP_FPS)
if args.video_path:
-
+        print('Please wait. Processing the video......')
        # Save the prediction results as a video
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
@@ -123,8 +123,8 @@ def video_infer(args):
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
- scoremap = 255 * score_map[:, :, 1]
- optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
+ score_map = 255 * score_map[:, :, 1]
+ optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
@@ -132,10 +132,11 @@ def video_infer(args):
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
- img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
- img_mat = recover(img_mat, im_info)
- bg_im = np.ones_like(img_mat) * 255
- comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
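+                # Blend the frame over a plain white background using the
+                # soft matting mask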
+ img_matting = np.repeat(
+ optflow_map[:, :, np.newaxis], 3, axis=2)
+ img_matting = recover(img_matting, im_info)
+ bg_im = np.ones_like(img_matting) * 255
+ comb = (img_matting * frame + (1 - img_matting) * bg_im).astype(
np.uint8)
out.write(comb)
else:
@@ -150,20 +151,20 @@ def video_infer(args):
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
- scoremap = 255 * score_map[:, :, 1]
- optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
+ score_map = 255 * score_map[:, :, 1]
+ optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
- # optflow_map = optflow_map/255.0
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
- img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
- img_mat = recover(img_mat, im_info)
- bg_im = np.ones_like(img_mat) * 255
- comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
+ img_matting = np.repeat(
+ optflow_map[:, :, np.newaxis], 3, axis=2)
+ img_matting = recover(img_matting, im_info)
+ bg_im = np.ones_like(img_matting) * 255
+ comb = (img_matting * frame + (1 - img_matting) * bg_im).astype(
np.uint8)
cv2.imshow('HumanSegmentation', comb)
if cv2.waitKey(1) & 0xFF == ord('q'):