Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
76e2799e
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
76e2799e
编写于
4月 01, 2020
作者:
S
sjtubinlong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix coding style
上级
c36a5ec2
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
127 addition
and
65 deletion
+127
-65
contrib/RealTimeHumanSeg/python/infer.py
contrib/RealTimeHumanSeg/python/infer.py
+127
-65
未找到文件。
contrib/RealTimeHumanSeg/python/infer.py
浏览文件 @
76e2799e
...
...
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python Inference solution for realtime humansegmentation"""
import
os
import
argparse
...
...
@@ -21,54 +23,31 @@ import cv2
import
paddle.fluid
as
fluid
def
parse_args
():
"""
Parsing command argments
"""
parser
=
argparse
.
ArgumentParser
(
'Realtime Human Segmentation'
)
parser
.
add_argument
(
'--model_dir'
,
type
=
str
,
default
=
''
,
help
=
'path of human segmentation model'
)
parser
.
add_argument
(
'--img_path'
,
type
=
str
,
default
=
''
,
help
=
'path of input image'
)
parser
.
add_argument
(
'--video_path'
,
type
=
str
,
default
=
''
,
help
=
'path of input video'
)
parser
.
add_argument
(
'--use_camera'
,
type
=
bool
,
default
=
False
,
help
=
'input video stream from camera'
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
False
,
help
=
'enable gpu'
)
return
parser
.
parse_args
()
def
get_round
(
data
):
"""
get round of data
"""
rnd
=
0.5
if
data
>=
0
else
-
0.5
return
(
int
)(
data
+
rnd
)
def
human_seg_tracking
(
pre_gray
,
cur_gray
,
prev_cfd
,
dl_weights
,
disflow
):
"""
human segmentation tracking
"""Optical flow tracking for human segmentation
Args:
pre_gray: Grayscale of previous frame.
cur_gray: Grayscale of current frame.
prev_cfd: Optical flow of previous frame.
dl_weights: Merged weights data.
disflow: A data structure represents optical flow.
Returns:
is_track: Binary graph, whethe a pixel matched with a optical flow point.
track_cfd: tracking optical flow image.
"""
check_thres
=
8
hgt
,
wdh
=
pre_gray
.
shape
[:
2
]
track_cfd
=
np
.
zeros_like
(
prev_cfd
)
is_track
=
np
.
zeros_like
(
pre_gray
)
# compute forward optical flow
flow_fw
=
disflow
.
calc
(
pre_gray
,
cur_gray
,
None
)
# compute backword optical flow
flow_bw
=
disflow
.
calc
(
cur_gray
,
pre_gray
,
None
)
get_round
=
lambda
data
:
(
int
)(
data
+
0.5
)
if
data
>=
0
else
(
int
)(
data
-
0.5
)
for
row
in
range
(
hgt
):
for
col
in
range
(
wdh
):
# Calculate new coordinate after optfow process.
# (row, col) -> (cur_x, cur_y)
fxy_fw
=
flow_fw
[
row
,
col
]
dx_fw
=
get_round
(
fxy_fw
[
0
])
cur_x
=
dx_fw
+
col
...
...
@@ -79,20 +58,27 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
fxy_bw
=
flow_bw
[
cur_y
,
cur_x
]
dx_bw
=
get_round
(
fxy_bw
[
0
])
dy_bw
=
get_round
(
fxy_bw
[
1
])
# Filt the Optical flow point with a threshold
lmt
=
((
dy_fw
+
dy_bw
)
*
(
dy_fw
+
dy_bw
)
+
(
dx_fw
+
dx_bw
)
*
(
dx_fw
+
dx_bw
))
if
lmt
>=
check_thres
:
continue
# Downgrade still points
if
abs
(
dy_fw
)
<=
0
and
abs
(
dx_fw
)
<=
0
and
abs
(
dy_bw
)
<=
0
and
abs
(
dx_bw
)
<=
0
:
dl_weights
[
cur_y
,
cur_x
]
=
0.05
is_track
[
cur_y
,
cur_x
]
=
1
track_cfd
[
cur_y
,
cur_x
]
=
prev_cfd
[
row
,
col
]
return
track_cfd
,
is_track
,
dl_weights
def
human_seg_track_fuse
(
track_cfd
,
dl_cfd
,
dl_weights
,
is_track
):
"""
human segmentation tracking fuse
"""Fusion of Optical flow track and segmentation
Args:
track_cfd: Optical flow track.
dl_cfd: Segmentation result of current frame.
dl_weights: Merged weights data.
is_track: Binary graph, whethe a pixel matched with a optical flow point.
Returns:
cur_cfd: Fusion of Optical flow track and segmentation result.
"""
cur_cfd
=
dl_cfd
.
copy
()
idxs
=
np
.
where
(
is_track
>
0
)
...
...
@@ -111,8 +97,13 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
def
threshold_mask
(
img
,
thresh_bg
,
thresh_fg
):
"""
threshold mask
"""Threshold mask for image foreground and background
Args:
img : Original image, an instance of np.uint8 array.
thresh_bg : Threshold for background, set to 0 when less than it.
thresh_fg : Threshold for foreground, set to 1 when greater than it.
Returns:
dst : Image after set thresthold mask, ans instance of np.float32 array.
"""
dst
=
(
img
/
255.0
-
thresh_bg
)
/
(
thresh_fg
-
thresh_bg
)
dst
[
np
.
where
(
dst
>
1
)]
=
1
...
...
@@ -121,8 +112,13 @@ def threshold_mask(img, thresh_bg, thresh_fg):
def
optflow_handle
(
cur_gray
,
scoremap
,
is_init
):
"""
optical flow handling
"""Processing optical flow and segmentation result.
Args:
cur_gray : Grayscale of current frame.
scoremap : Segmentation result of current frame.
is_init : True only when process the first frame of a video.
Returns:
dst : Image after set thresthold mask, ans instance of np.float32 array.
"""
width
,
height
=
scoremap
.
shape
[
0
],
scoremap
.
shape
[
1
]
disflow
=
cv2
.
DISOpticalFlow_create
(
...
...
@@ -149,18 +145,25 @@ def optflow_handle(cur_gray, scoremap, is_init):
class
HumanSeg
:
"""
Human Segmentation Class
"""Human Segmentation Class
This Class instance will load the inference model and do inference
on input image object.
It includes the key stages for a object segmentation inference task.
Call run_predict on your image and it will return a processed image.
"""
def
__init__
(
self
,
model_dir
,
mean
,
scale
,
eval_size
,
use_gpu
=
False
):
self
.
mean
=
np
.
array
(
mean
).
reshape
((
3
,
1
,
1
))
self
.
scale
=
np
.
array
(
scale
).
reshape
((
3
,
1
,
1
))
self
.
eval_size
=
eval_size
self
.
load_model
(
model_dir
,
use_gpu
)
def
load_model
(
self
,
model_dir
,
use_gpu
):
"""
Load model from model_dir
"""Load paddle inference model.
Args:
model_dir: The inference model path includes `__model__` and `__params__`.
use_gpu: Enable gpu if use_gpu is True
"""
prog_file
=
os
.
path
.
join
(
model_dir
,
'__model__'
)
params_file
=
os
.
path
.
join
(
model_dir
,
'__params__'
)
...
...
@@ -176,8 +179,12 @@ class HumanSeg:
self
.
predictor
=
fluid
.
core
.
create_paddle_predictor
(
config
)
def
preprocess
(
self
,
image
):
"""
preprocess image: hwc_rgb to chw_bgr
"""Preprocess input image.
Convert hwc_rgb to chw_bgr.
Args:
image: The input opencv image object.
Returns:
A preprocessed image object.
"""
img_mat
=
cv2
.
resize
(
image
,
self
.
eval_size
,
interpolation
=
cv2
.
INTER_LINEAR
)
...
...
@@ -193,8 +200,12 @@ class HumanSeg:
return
img_mat
def
postprocess
(
self
,
image
,
output_data
):
"""
postprocess result: merge background with segmentation result
"""Postprocess the inference result and original input image.
Args:
image: The original opencv image object.
output_data: The inference output of paddle's humansegmentation model.
Returns:
The result merged original image and segmentation result with optical-flow improvement.
"""
scoremap
=
output_data
[
0
,
1
,
:,
:]
scoremap
=
(
scoremap
*
255
).
astype
(
np
.
uint8
)
...
...
@@ -213,8 +224,12 @@ class HumanSeg:
return
comb
def
run_predict
(
self
,
image
):
"""
run predict: return segmentation image mat
"""Run Predicting on an opencv image object.
Preprocess the image, do inference, and then postprocess the infering output.
Args:
image: A valid opencv image object.
Returns:
The segmentation result which represents as an opencv image object.
"""
im_mat
=
self
.
preprocess
(
image
)
im_tensor
=
fluid
.
core
.
PaddleTensor
(
im_mat
.
copy
().
astype
(
'float32'
))
...
...
@@ -224,8 +239,13 @@ class HumanSeg:
def
predict_image
(
seg
,
image_path
):
"""
Do Predicting on a single image
"""Do Predicting on a image file.
Decoding the image file and do predicting on it.
The result will be saved as `result.jpeg`.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
image_path: Path of the image file needs to be processed.
"""
img_mat
=
cv2
.
imread
(
image_path
)
img_mat
=
seg
.
run_predict
(
img_mat
)
...
...
@@ -233,8 +253,13 @@ def predict_image(seg, image_path):
def
predict_video
(
seg
,
video_path
):
"""
Do Predicting on a video
"""Do Predicting on a video file.
Decoding the video file and do predicting on each frame.
All result will be saved as `result.avi`.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
video_path: Path of a video file needs to be processed.
"""
cap
=
cv2
.
VideoCapture
(
video_path
)
if
not
cap
.
isOpened
():
...
...
@@ -260,8 +285,12 @@ def predict_video(seg, video_path):
def
predict_camera
(
seg
):
"""
Do Predicting on a camera video stream: Press q to exit
"""Do Predicting on a camera video stream.
Capturing each video frame from camera and do predicting on it.
All result frames will be shown in a GUI window.
Args:
seg: The HumanSeg Object which holds a inference model.
Do preprocessing / predicting / postprocessing on a input image object.
"""
cap
=
cv2
.
VideoCapture
(
0
)
if
not
cap
.
isOpened
():
...
...
@@ -281,8 +310,14 @@ def predict_camera(seg):
def
main
(
args
):
"""
Entrypoint of the script
"""Real Entrypoint of the script.
Load the human segmentation inference model and do predicting on the input resource.
Support three types of input: camera stream / video file / image file.
Args:
args: The command-line args for inference model.
Open camera and do predicting on camera stream while `args.use_camera` is true.
Open the video file and do predicting on it while `args.video_path` is valid.
Open the image file and do predicting on it while `args.img_path` is valid.
"""
model_dir
=
args
.
model_dir
use_gpu
=
args
.
use_gpu
...
...
@@ -293,16 +328,43 @@ def main(args):
eval_size
=
(
192
,
192
)
seg
=
HumanSeg
(
model_dir
,
mean
,
scale
,
eval_size
,
use_gpu
)
if
args
.
use_camera
:
# if enable input video stream from
video
# if enable input video stream from
camera
predict_camera
(
seg
)
elif
args
.
video_path
:
# if video_path valid, do predicting on video
# if video_path valid, do predicting on
the
video
predict_video
(
seg
,
args
.
video_path
)
elif
args
.
img_path
:
# if img_path valid, do predicting on the image
predict_image
(
seg
,
args
.
img_path
)
def
parse_args
():
"""Parsing command-line argments
"""
parser
=
argparse
.
ArgumentParser
(
'Realtime Human Segmentation'
)
parser
.
add_argument
(
'--model_dir'
,
type
=
str
,
default
=
''
,
help
=
'path of human segmentation model'
)
parser
.
add_argument
(
'--img_path'
,
type
=
str
,
default
=
''
,
help
=
'path of input image'
)
parser
.
add_argument
(
'--video_path'
,
type
=
str
,
default
=
''
,
help
=
'path of input video'
)
parser
.
add_argument
(
'--use_camera'
,
type
=
bool
,
default
=
False
,
help
=
'input video stream from camera'
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
bool
,
default
=
False
,
help
=
'enable gpu'
)
return
parser
.
parse_args
()
if
__name__
==
"__main__"
:
args
=
parse_args
()
main
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录