提交 76e2799e 编写于 作者: S sjtubinlong

fix coding style

上级 c36a5ec2
......@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python Inference solution for realtime humansegmentation"""
import os
import argparse
......@@ -21,54 +23,31 @@ import cv2
import paddle.fluid as fluid
def parse_args():
"""
Parsing command argments
"""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument('--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument('--img_path',
type=str,
default='',
help='path of input image')
parser.add_argument('--video_path',
type=str,
default='',
help='path of input video')
parser.add_argument('--use_camera',
type=bool,
default=False,
help='input video stream from camera')
parser.add_argument('--use_gpu',
type=bool,
default=False,
help='enable gpu')
return parser.parse_args()
def get_round(data):
"""
get round of data
"""
rnd = 0.5 if data >= 0 else -0.5
return (int)(data + rnd)
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
"""
human segmentation tracking
"""Optical flow tracking for human segmentation
Args:
pre_gray: Grayscale of previous frame.
cur_gray: Grayscale of current frame.
prev_cfd: Optical flow of previous frame.
dl_weights: Merged weights data.
disflow: A data structure represents optical flow.
Returns:
is_track: Binary graph, whethe a pixel matched with a optical flow point.
track_cfd: tracking optical flow image.
"""
check_thres = 8
hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
# compute forward optical flow
flow_fw = disflow.calc(pre_gray, cur_gray, None)
# compute backword optical flow
flow_bw = disflow.calc(cur_gray, pre_gray, None)
get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -0.5)
for row in range(hgt):
for col in range(wdh):
# Calculate new coordinate after optfow process.
# (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + col
......@@ -79,20 +58,27 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
# Filt the Optical flow point with a threshold
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres:
continue
# Downgrade still points
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[row, col]
return track_cfd, is_track, dl_weights
def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
"""
human segmentation tracking fuse
"""Fusion of Optical flow track and segmentation
Args:
track_cfd: Optical flow track.
dl_cfd: Segmentation result of current frame.
dl_weights: Merged weights data.
is_track: Binary graph, whethe a pixel matched with a optical flow point.
Returns:
cur_cfd: Fusion of Optical flow track and segmentation result.
"""
cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
......@@ -111,8 +97,13 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
def threshold_mask(img, thresh_bg, thresh_fg):
"""
threshold mask
"""Threshold mask for image foreground and background
Args:
img : Original image, an instance of np.uint8 array.
thresh_bg : Threshold for background, set to 0 when less than it.
thresh_fg : Threshold for foreground, set to 1 when greater than it.
Returns:
dst : Image after set thresthold mask, ans instance of np.float32 array.
"""
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
......@@ -121,8 +112,13 @@ def threshold_mask(img, thresh_bg, thresh_fg):
def optflow_handle(cur_gray, scoremap, is_init):
"""
optical flow handling
"""Processing optical flow and segmentation result.
Args:
cur_gray : Grayscale of current frame.
scoremap : Segmentation result of current frame.
is_init : True only when process the first frame of a video.
Returns:
dst : Image after set thresthold mask, ans instance of np.float32 array.
"""
width, height = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(
......@@ -149,18 +145,25 @@ def optflow_handle(cur_gray, scoremap, is_init):
class HumanSeg:
"""
Human Segmentation Class
"""Human Segmentation Class
This Class instance will load the inference model and do inference
on input image object.
It includes the key stages for a object segmentation inference task.
Call run_predict on your image and it will return a processed image.
"""
def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
self.mean = np.array(mean).reshape((3, 1, 1))
self.scale = np.array(scale).reshape((3, 1, 1))
self.eval_size = eval_size
self.load_model(model_dir, use_gpu)
def load_model(self, model_dir, use_gpu):
"""
Load model from model_dir
"""Load paddle inference model.
Args:
model_dir: The inference model path includes `__model__` and `__params__`.
use_gpu: Enable gpu if use_gpu is True
"""
prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__')
......@@ -176,8 +179,12 @@ class HumanSeg:
self.predictor = fluid.core.create_paddle_predictor(config)
def preprocess(self, image):
"""
preprocess image: hwc_rgb to chw_bgr
"""Preprocess input image.
Convert hwc_rgb to chw_bgr.
Args:
image: The input opencv image object.
Returns:
A preprocessed image object.
"""
img_mat = cv2.resize(
image, self.eval_size, interpolation=cv2.INTER_LINEAR)
......@@ -193,8 +200,12 @@ class HumanSeg:
return img_mat
def postprocess(self, image, output_data):
"""
postprocess result: merge background with segmentation result
"""Postprocess the inference result and original input image.
Args:
image: The original opencv image object.
output_data: The inference output of paddle's humansegmentation model.
Returns:
The result merged original image and segmentation result with optical-flow improvement.
"""
scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8)
......@@ -213,8 +224,12 @@ class HumanSeg:
return comb
def run_predict(self, image):
"""
run predict: return segmentation image mat
"""Run Predicting on an opencv image object.
Preprocess the image, do inference, and then postprocess the infering output.
Args:
image: A valid opencv image object.
Returns:
The segmentation result which represents as an opencv image object.
"""
im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
......@@ -224,8 +239,13 @@ class HumanSeg:
def predict_image(seg, image_path):
"""
Do Predicting on a single image
"""Do Predicting on a image file.
Decoding the image file and do predicting on it.
The result will be saved as `result.jpeg`.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
image_path: Path of the image file needs to be processed.
"""
img_mat = cv2.imread(image_path)
img_mat = seg.run_predict(img_mat)
......@@ -233,8 +253,13 @@ def predict_image(seg, image_path):
def predict_video(seg, video_path):
"""
Do Predicting on a video
"""Do Predicting on a video file.
Decoding the video file and do predicting on each frame.
All result will be saved as `result.avi`.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
video_path: Path of a video file needs to be processed.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
......@@ -260,8 +285,12 @@ def predict_video(seg, video_path):
def predict_camera(seg):
"""
Do Predicting on a camera video stream: Press q to exit
"""Do Predicting on a camera video stream.
Capturing each video frame from camera and do predicting on it.
All result frames will be shown in a GUI window.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
"""
cap = cv2.VideoCapture(0)
if not cap.isOpened():
......@@ -281,8 +310,14 @@ def predict_camera(seg):
def main(args):
"""
Entrypoint of the script
"""Real Entrypoint of the script.
Load the human segmentation inference model and do predicting on the input resource.
Support three types of input: camera stream / video file / image file.
Args:
args: The command-line args for inference model.
Open camera and do predicting on camera stream while `args.use_camera` is true.
Open the video file and do predicting on it while `args.video_path` is valid.
Open the image file and do predicting on it while `args.img_path` is valid.
"""
model_dir = args.model_dir
use_gpu = args.use_gpu
......@@ -293,16 +328,43 @@ def main(args):
eval_size = (192, 192)
seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
if args.use_camera:
# if enable input video stream from video
# if enable input video stream from camera
predict_camera(seg)
elif args.video_path:
# if video_path valid, do predicting on video
# if video_path valid, do predicting on the video
predict_video(seg, args.video_path)
elif args.img_path:
# if img_path valid, do predicting on the image
predict_image(seg, args.img_path)
def parse_args():
"""Parsing command-line argments
"""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument('--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument('--img_path',
type=str,
default='',
help='path of input image')
parser.add_argument('--video_path',
type=str,
default='',
help='path of input video')
parser.add_argument('--use_camera',
type=bool,
default=False,
help='input video stream from camera')
parser.add_argument('--use_gpu',
type=bool,
default=False,
help='enable gpu')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册