提交 f71ef43f 编写于 作者: S sjtubinlong

RealTimeHumanSeg: add comments

上级 07ae0bdf
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Python Inference solution for realtime humansegmentation""" """实时人像分割Python预测部署"""
import os import os
import argparse import argparse
...@@ -24,29 +24,29 @@ import paddle.fluid as fluid ...@@ -24,29 +24,29 @@ import paddle.fluid as fluid
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
"""Optical flow tracking for human segmentation """计算光流跟踪匹配点和光流图
Args: 输入参数:
pre_gray: Grayscale of previous frame. pre_gray: 上一帧灰度图
cur_gray: Grayscale of current frame. cur_gray: 当前帧灰度图
prev_cfd: Optical flow of previous frame. prev_cfd: 上一帧光流图
dl_weights: Merged weights data. dl_weights: 融合权重图
disflow: A data structure represents optical flow. disflow: 光流数据结构
Returns: 返回值:
is_track: Binary graph, whethe a pixel matched with a optical flow point. is_track: 光流点跟踪二值图,即是否具有光流点匹配
track_cfd: tracking optical flow image. track_cfd: 光流跟踪图
""" """
check_thres = 8 check_thres = 8
hgt, wdh = pre_gray.shape[:2] hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd) track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray) is_track = np.zeros_like(pre_gray)
# compute forward optical flow # 计算前向光流
flow_fw = disflow.calc(pre_gray, cur_gray, None) flow_fw = disflow.calc(pre_gray, cur_gray, None)
# compute backword optical flow # 计算后向光流
flow_bw = disflow.calc(cur_gray, pre_gray, None) flow_bw = disflow.calc(cur_gray, pre_gray, None)
get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -0.5) get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -0.5)
for row in range(hgt): for row in range(hgt):
for col in range(wdh): for col in range(wdh):
# Calculate new coordinate after optfow process. # 计算光流处理后对应点坐标
# (row, col) -> (cur_x, cur_y) # (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col] fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0]) dx_fw = get_round(fxy_fw[0])
...@@ -58,11 +58,11 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): ...@@ -58,11 +58,11 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
fxy_bw = flow_bw[cur_y, cur_x] fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0]) dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1]) dy_bw = get_round(fxy_bw[1])
# Filt the Optical flow point with a threshold # 光流移动小于阈值
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw)) lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) + (dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres: if lmt >= check_thres:
continue continue
# Downgrade still points # 静止点降权
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0: if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05 dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1 is_track[cur_y, cur_x] = 1
...@@ -71,14 +71,14 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow): ...@@ -71,14 +71,14 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
"""Fusion of Optical flow track and segmentation """光流追踪图和人像分割结构融合
Args: 输入参数:
track_cfd: Optical flow track. track_cfd: 光流追踪图
dl_cfd: Segmentation result of current frame. dl_cfd: 当前帧分割结果
dl_weights: Merged weights data. dl_weights: 融合权重图
is_track: Binary graph, whethe a pixel matched with a optical flow point. is_track: 光流点匹配二值图
Returns: 返回值:
cur_cfd: Fusion of Optical flow track and segmentation result. cur_cfd: 光流跟踪图和人像分割结果融合图
""" """
cur_cfd = dl_cfd.copy() cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0) idxs = np.where(is_track > 0)
...@@ -97,13 +97,13 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track): ...@@ -97,13 +97,13 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
def threshold_mask(img, thresh_bg, thresh_fg): def threshold_mask(img, thresh_bg, thresh_fg):
"""Threshold mask for image foreground and background """设置背景和前景阈值mask
Args: 输入参数:
img : Original image, an instance of np.uint8 array. img : 原始图像, np.uint8 类型.
thresh_bg : Threshold for background, set to 0 when less than it. thresh_bg : 背景阈值百分比,低于该值置为0.
thresh_fg : Threshold for foreground, set to 1 when greater than it. thresh_fg : 前景阈值百分比,超过该值置为1.
Returns: 返回值:
dst : Image after set thresthold mask, ans instance of np.float32 array. dst : 原始图像设置完前景背景阈值mask结果, np.float32 类型.
""" """
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg) dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1 dst[np.where(dst > 1)] = 1
...@@ -112,13 +112,13 @@ def threshold_mask(img, thresh_bg, thresh_fg): ...@@ -112,13 +112,13 @@ def threshold_mask(img, thresh_bg, thresh_fg):
def optflow_handle(cur_gray, scoremap, is_init): def optflow_handle(cur_gray, scoremap, is_init):
"""Processing optical flow and segmentation result. """光流优化
Args: Args:
cur_gray : Grayscale of current frame. cur_gray : 当前帧灰度图
scoremap : Segmentation result of current frame. scoremap : 当前帧分割结果
is_init : True only when process the first frame of a video. is_init : 是否第一帧
Returns: Returns:
dst : Image after set thresthold mask, ans instance of np.float32 array. dst : 光流追踪图和预测结果融合图, 类型为 np.float32
""" """
width, height = scoremap.shape[0], scoremap.shape[1] width, height = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create( disflow = cv2.DISOpticalFlow_create(
...@@ -145,12 +145,8 @@ def optflow_handle(cur_gray, scoremap, is_init): ...@@ -145,12 +145,8 @@ def optflow_handle(cur_gray, scoremap, is_init):
class HumanSeg: class HumanSeg:
"""Human Segmentation Class """人像分割类
This Class instance will load the inference model and do inference 封装了人像分割模型的加载,数据预处理,预测,后处理等
on input image object.
It includes the key stages for a object segmentation inference task.
Call run_predict on your image and it will return a processed image.
""" """
def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False): def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
...@@ -160,10 +156,10 @@ class HumanSeg: ...@@ -160,10 +156,10 @@ class HumanSeg:
self.load_model(model_dir, use_gpu) self.load_model(model_dir, use_gpu)
def load_model(self, model_dir, use_gpu): def load_model(self, model_dir, use_gpu):
"""Load paddle inference model. """加载模型并创建predictor
Args: Args:
model_dir: The inference model path includes `__model__` and `__params__`. model_dir: 预测模型路径, 包含 `__model__` 和 `__params__`
use_gpu: Enable gpu if use_gpu is True use_gpu: 是否使用GPU加速
""" """
prog_file = os.path.join(model_dir, '__model__') prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__') params_file = os.path.join(model_dir, '__params__')
...@@ -179,12 +175,12 @@ class HumanSeg: ...@@ -179,12 +175,12 @@ class HumanSeg:
self.predictor = fluid.core.create_paddle_predictor(config) self.predictor = fluid.core.create_paddle_predictor(config)
def preprocess(self, image): def preprocess(self, image):
"""Preprocess input image. """图像预处理
Convert hwc_rgb to chw_bgr. hwc_rgb 转换为 chw_bgr,并进行归一化
Args: 输入参数:
image: The input opencv image object. image: 原始图像
Returns: 返回值:
A preprocessed image object. 经过预处理后的图片结果
""" """
img_mat = cv2.resize( img_mat = cv2.resize(
image, self.eval_size, interpolation=cv2.INTER_LINEAR) image, self.eval_size, interpolation=cv2.INTER_LINEAR)
...@@ -200,18 +196,18 @@ class HumanSeg: ...@@ -200,18 +196,18 @@ class HumanSeg:
return img_mat return img_mat
def postprocess(self, image, output_data): def postprocess(self, image, output_data):
"""Postprocess the inference result and original input image. """对预测结果进行后处理
Args: Args:
image: The original opencv image object. image: 原始图,opencv 图片对象
output_data: The inference output of paddle's humansegmentation model. output_data: Paddle预测结果原始数据
Returns: Returns:
The result merged original image and segmentation result with optical-flow improvement. 原图和预测结果融合并做了光流优化的结果图
""" """
scoremap = output_data[0, 1, :, :] scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8) scoremap = (scoremap * 255).astype(np.uint8)
ori_h, ori_w = image.shape[0], image.shape[1] ori_h, ori_w = image.shape[0], image.shape[1]
evl_h, evl_w = self.eval_size[0], self.eval_size[1] evl_h, evl_w = self.eval_size[0], self.eval_size[1]
# optical flow processing # 光流处理
cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (evl_w, evl_h)) cur_gray = cv2.resize(cur_gray, (evl_w, evl_h))
optflow_map = optflow_handle(cur_gray, scoremap, False) optflow_map = optflow_handle(cur_gray, scoremap, False)
...@@ -224,12 +220,11 @@ class HumanSeg: ...@@ -224,12 +220,11 @@ class HumanSeg:
return comb return comb
def run_predict(self, image): def run_predict(self, image):
"""Run Predicting on an opencv image object. """运行预测并返回可视化结果图
Preprocess the image, do inference, and then postprocess the infering output. 输入参数:
Args: image: 需要预测的原始图, opencv图片对象
image: A valid opencv image object. 返回值:
Returns: 可视化的预测结果图
The segmentation result which represents as an opencv image object.
""" """
im_mat = self.preprocess(image) im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32')) im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
...@@ -239,13 +234,8 @@ class HumanSeg: ...@@ -239,13 +234,8 @@ class HumanSeg:
def predict_image(seg, image_path): def predict_image(seg, image_path):
"""Do Predicting on a image file. """对图片文件进行分割
Decoding the image file and do predicting on it. 结果保存到`result.jpeg`文件中
The result will be saved as `result.jpeg`.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
image_path: Path of the image file needs to be processed.
""" """
img_mat = cv2.imread(image_path) img_mat = cv2.imread(image_path)
img_mat = seg.run_predict(img_mat) img_mat = seg.run_predict(img_mat)
...@@ -253,13 +243,8 @@ def predict_image(seg, image_path): ...@@ -253,13 +243,8 @@ def predict_image(seg, image_path):
def predict_video(seg, video_path): def predict_video(seg, video_path):
"""Do Predicting on a video file. """对视频文件进行分割
Decoding the video file and do predicting on each frame. 结果保存到`result.avi`文件中
All result will be saved as `result.avi`.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
video_path: Path of a video file needs to be processed.
""" """
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
if not cap.isOpened(): if not cap.isOpened():
...@@ -268,11 +253,11 @@ def predict_video(seg, video_path): ...@@ -268,11 +253,11 @@ def predict_video(seg, video_path):
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
# Result Video Writer # 用于保存预测结果视频
out = cv2.VideoWriter('result.avi', out = cv2.VideoWriter('result.avi',
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
(width, height)) (width, height))
# Start capturing from video # 开始获取视频帧
while cap.isOpened(): while cap.isOpened():
ret, frame = cap.read() ret, frame = cap.read()
if ret: if ret:
...@@ -285,12 +270,8 @@ def predict_video(seg, video_path): ...@@ -285,12 +270,8 @@ def predict_video(seg, video_path):
def predict_camera(seg): def predict_camera(seg):
"""Do Predicting on a camera video stream. """从摄像头获取视频流进行预测
Capturing each video frame from camera and do predicting on it. 视频分割结果实时显示到可视化窗口中
All result frames will be shown in a GUI window.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
""" """
cap = cv2.VideoCapture(0) cap = cv2.VideoCapture(0)
if not cap.isOpened(): if not cap.isOpened():
...@@ -310,36 +291,30 @@ def predict_camera(seg): ...@@ -310,36 +291,30 @@ def predict_camera(seg):
def main(args): def main(args):
"""Real Entrypoint of the script. """预测程序入口
Load the human segmentation inference model and do predicting on the input resource. 完成模型加载, 对视频、摄像头、图片文件等预测过程
Support three types of input: camera stream / video file / image file.
Args:
args: The command-line args for inference model.
Open camera and do predicting on camera stream while `args.use_camera` is true.
Open the video file and do predicting on it while `args.video_path` is valid.
Open the image file and do predicting on it while `args.img_path` is valid.
""" """
model_dir = args.model_dir model_dir = args.model_dir
use_gpu = args.use_gpu use_gpu = args.use_gpu
# Init model # 加载模型
mean = [104.008, 116.669, 122.675] mean = [104.008, 116.669, 122.675]
scale = [1.0, 1.0, 1.0] scale = [1.0, 1.0, 1.0]
eval_size = (192, 192) eval_size = (192, 192)
seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu) seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
if args.use_camera: if args.use_camera:
# if enable input video stream from camera # 开启摄像头
predict_camera(seg) predict_camera(seg)
elif args.video_path: elif args.video_path:
# if video_path valid, do predicting on the video # 使用视频文件作为输入
predict_video(seg, args.video_path) predict_video(seg, args.video_path)
elif args.img_path: elif args.img_path:
# if img_path valid, do predicting on the image # 使用图片文件作为输入
predict_image(seg, args.img_path) predict_image(seg, args.img_path)
def parse_args(): def parse_args():
"""Parsing command-line argments """解析命令行参数
""" """
parser = argparse.ArgumentParser('Realtime Human Segmentation') parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument('--model_dir', parser.add_argument('--model_dir',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册