Commit ccea568d authored by sjtubinlong

fix coding style

Parent b7d40a93
@@ -14,20 +14,47 @@
 # limitations under the License.
 import os
-import sys
+import argparse
 import numpy as np
 import cv2
 import paddle.fluid as fluid
+def parse_args():
+    """
+    Parse command-line arguments.
+    """
+    parser = argparse.ArgumentParser('Realtime Human Segmentation')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='',
+                        help='path of human segmentation model')
+    parser.add_argument('--img_path',
+                        type=str,
+                        default='',
+                        help='path of input image')
+    parser.add_argument('--video_path',
+                        type=str,
+                        default='',
+                        help='path of input video')
+    parser.add_argument('--use_camera',
+                        type=bool,
+                        default=False,
+                        help='input video stream from camera')
+    parser.add_argument('--use_gpu',
+                        type=bool,
+                        default=False,
+                        help='enable gpu')
+    return parser.parse_args()
 def get_round(data):
     """
     get round of data
     """
-    round = 0.5 if data >= 0 else -0.5
-    return (int)(data + round)
+    rnd = 0.5 if data >= 0 else -0.5
+    return (int)(data + rnd)
 def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
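A note on the new flags: argparse's `type=bool` does not parse boolean strings, since `bool('False')` is `True`, so any non-empty value passed to `--use_gpu` or `--use_camera` enables them. A minimal sketch of the usual alternative (not part of this commit):

```python
import argparse

# Sketch only, not in this commit: a store_true flag avoids the type=bool trap.
parser = argparse.ArgumentParser('bool flag demo')
parser.add_argument('--use_gpu', action='store_true', help='enable gpu')

print(bool('False'))                             # True: why type=bool misleads
print(parser.parse_args([]).use_gpu)             # False
print(parser.parse_args(['--use_gpu']).use_gpu)  # True
```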
@@ -35,29 +62,30 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
     human segmentation tracking
     """
     check_thres = 8
-    h, w = pre_gray.shape[:2]
+    hgt, wdh = pre_gray.shape[:2]
     track_cfd = np.zeros_like(prev_cfd)
     is_track = np.zeros_like(pre_gray)
     flow_fw = disflow.calc(pre_gray, cur_gray, None)
     flow_bw = disflow.calc(cur_gray, pre_gray, None)
-    for r in range(h):
-        for c in range(w):
-            fxy_fw = flow_fw[r, c]
+    for row in range(hgt):
+        for col in range(wdh):
+            fxy_fw = flow_fw[row, col]
             dx_fw = get_round(fxy_fw[0])
-            cur_x = dx_fw + c
+            cur_x = dx_fw + col
             dy_fw = get_round(fxy_fw[1])
-            cur_y = dy_fw + r
-            if cur_x < 0 or cur_x >= w or cur_y < 0 or cur_y >= h:
+            cur_y = dy_fw + row
+            if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
                 continue
             fxy_bw = flow_bw[cur_y, cur_x]
             dx_bw = get_round(fxy_bw[0])
             dy_bw = get_round(fxy_bw[1])
-            if ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw)) >= check_thres:
+            lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw))
+            if lmt >= check_thres:
                 continue
             if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0:
                 dl_weights[cur_y, cur_x] = 0.05
             is_track[cur_y, cur_x] = 1
-            track_cfd[cur_y, cur_x] = prev_cfd[r, c]
+            track_cfd[cur_y, cur_x] = prev_cfd[row, col]
     return track_cfd, is_track, dl_weights
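For readers following the loop above: it warps each pixel by the forward flow, then checks that the backward flow at the landing point roughly cancels it (a forward-backward consistency test) before copying the previous confidence. A vectorized sketch of the same test, assuming the flow arrays above (results match up to how the displacements are rounded; `np.rint` rounds half to even, while `get_round` rounds half away from zero):

```python
import numpy as np

def fb_consistent(flow_fw, flow_bw, check_thres=8):
    """Mask of pixels whose forward and backward flows roughly cancel."""
    hgt, wdh = flow_fw.shape[:2]
    cols, rows = np.meshgrid(np.arange(wdh), np.arange(hgt))
    cur_x = cols + np.rint(flow_fw[..., 0]).astype(np.int64)
    cur_y = rows + np.rint(flow_fw[..., 1]).astype(np.int64)
    inside = (cur_x >= 0) & (cur_x < wdh) & (cur_y >= 0) & (cur_y < hgt)
    # Clip for safe indexing; out-of-frame pixels are masked out by `inside`.
    fxy_bw = flow_bw[np.clip(cur_y, 0, hgt - 1), np.clip(cur_x, 0, wdh - 1)]
    err = (flow_fw[..., 0] + fxy_bw[..., 0]) ** 2 \
        + (flow_fw[..., 1] + fxy_bw[..., 1]) ** 2
    return inside & (err < check_thres)
```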
@@ -78,7 +106,7 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
             else:
                 cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
         else:
-            cur_cfd[x, y] = dl_weights[x,y]*dl_score + (1-dl_weights[x,y])*track_score
+            cur_cfd[x, y] = dl_weights[x, y] * dl_score + (1 - dl_weights[x, y]) * track_score
     return cur_cfd
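The reformatted line is a per-pixel convex combination of the network confidence and the tracked confidence; treating the inputs as arrays, the untracked branch collapses to one vectorized expression (a sketch under that assumption; the helper name is hypothetical):

```python
import numpy as np

# Sketch: the untracked branch of human_seg_track_fuse as one array expression.
def blend(dl_cfd, track_cfd, dl_weights):
    return dl_weights * dl_cfd + (1.0 - dl_weights) * track_cfd
```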
@@ -96,22 +124,23 @@ def optflow_handle(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
     """
     optical flow handling
     """
-    w, h = scoremap.shape[0], scoremap.shape[1]
+    width, height = scoremap.shape[0], scoremap.shape[1]
     cur_cfd = scoremap.copy()
     if is_init:
         is_init = False
-        if h <= 64 or w <= 64:
+        if height <= 64 or width <= 64:
             disflow.setFinestScale(1)
-        elif h <= 160 or w <= 160:
+        elif height <= 160 or width <= 160:
             disflow.setFinestScale(2)
         else:
             disflow.setFinestScale(3)
         fusion_cfd = cur_cfd
     else:
-        weights = np.ones((w,h), np.float32) * 0.3
-        track_cfd, is_track, weights = human_seg_tracking(prev_gray, cur_gray, pre_cfd, weights, disflow)
+        weights = np.ones((width, height), np.float32) * 0.3
+        track_cfd, is_track, weights = human_seg_tracking(
+            prev_gray, cur_gray, pre_cfd, weights, disflow)
         fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)
-    fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3,3), 0)
+    fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
     return fusion_cfd
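The diff does not show where `disflow` is constructed. In OpenCV 4.x, a DIS optical-flow object with the tunable finest scale used above is typically created as follows (an assumption, not taken from this file):

```python
import cv2

# Assumed construction of the disflow object used by optflow_handle.
# setFinestScale(n) stops pyramid refinement n levels above full resolution,
# trading accuracy for speed, which is why small inputs get a smaller n above.
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
disflow.setFinestScale(2)
```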
@@ -179,8 +208,8 @@ class HumanSeg:
         optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
         optflow_map = cv2.resize(optflow_map, (ori_w, ori_h))
         optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
-        bg = np.ones_like(optflow_map) * 255
-        comb = (optflow_map * image + (1 - optflow_map) * bg).astype(np.uint8)
+        bg_im = np.ones_like(optflow_map) * 255
+        comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(np.uint8)
         return comb
     def run_predict(self, image):
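Here `optflow_map` acts as a three-channel soft alpha matte in [0, 1], composited against solid white. Swapping in an arbitrary background is a small change; a hedged sketch (the helper name and path are placeholders, not from the source):

```python
import cv2
import numpy as np

def composite_over(image, optflow_map, bg_path='bg.jpg'):
    """Blend the segmented person over a background image instead of white."""
    bg_im = cv2.imread(bg_path)
    bg_im = cv2.resize(bg_im, (image.shape[1], image.shape[0]))
    return (optflow_map * image + (1 - optflow_map) * bg_im).astype(np.uint8)
```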
@@ -218,16 +247,12 @@ def predict_video(seg, video_path):
     out = cv2.VideoWriter('result.avi',
                           cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
                           (width, height))
-    id = 1
     # Start capturing from video
     while cap.isOpened():
         ret, frame = cap.read()
         if ret:
             img_mat = seg.run_predict(frame)
             out.write(img_mat)
-            id += 1
-            if id >= 51:
-                break
         else:
             break
     cap.release()
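Dropping the `id` counter (which also shadowed the `id` builtin) means the writer now processes the whole stream instead of stopping after 50 frames. The `fps`, `width`, and `height` the writer needs are presumably queried from the opened capture before the loop; a sketch of the standard OpenCV calls (not shown in this hunk, path is a placeholder):

```python
import cv2

cap = cv2.VideoCapture('input.mp4')  # placeholder path
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
```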
@@ -259,23 +284,25 @@ def main(argv):
     """
     Entrypoint of the script
     """
-    if len(argv) < 3:
-        print('Usage: python infer.py /path/to/model/ /path/to/video')
-        return
-    model_dir = sys.argv[1]
-    input_path = sys.argv[2]
-    use_gpu = int(sys.argv[3]) if len(sys.argv) >= 4 else 0
+    model_dir = args.model_dir
+    use_gpu = args.use_gpu
     # Init model
     mean = [104.008, 116.669, 122.675]
     scale = [1.0, 1.0, 1.0]
     eval_size = (192, 192)
     seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
-    # Run Predicting on a video and result will be saved as result.avi
-    #predict_camera(seg)
-    predict_video(seg, input_path)
-    #predict_image(seg, input_path)
+    if args.use_camera:
+        # read the input video stream from the camera
+        predict_camera(seg)
+    elif args.video_path:
+        # if video_path is valid, run prediction on the video
+        predict_video(seg, args.video_path)
+    elif args.img_path:
+        # if img_path is valid, run prediction on the image
+        predict_image(seg, args.img_path)
 if __name__ == "__main__":
-    main(sys.argv)
+    args = parse_args()
+    main(args)
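With positional `sys.argv` parsing gone, the script is driven entirely by flags; example invocations (placeholder paths, not from the source):

```python
# python infer.py --model_dir ./human_seg_model --img_path ./person.jpg
# python infer.py --model_dir ./human_seg_model --video_path ./input.mp4 --use_gpu 1
# python infer.py --model_dir ./human_seg_model --use_camera 1
#
# Dispatch precedence in main(): --use_camera, then --video_path, then --img_path.
```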