Commit 35e66572 — s920243400/PaddleDetection (fork of PaddlePaddle/PaddleDetection)
Authored by zhiboniu on Jun 23, 2021; committed via GitHub on Jun 23, 2021. Parent: e6919efd

    add hrnet mpii dataset; (#3460)
    add dark deploy supported, mpii deploy supported;

Showing 15 changed files with 425 additions and 44 deletions (+425 −44).
Changed files:

- configs/keypoint/README.md (+10 −4)
- configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml (+0 −3)
- configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml (+0 −3)
- configs/keypoint/hrnet/hrnet_w32_256x192.yml (+0 −3)
- configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml (+130 −0, new file)
- configs/keypoint/hrnet/hrnet_w32_384x288.yml (+0 −3)
- deploy/python/keypoint_det_unite_infer.py (+2 −1)
- deploy/python/keypoint_infer.py (+6 −3)
- deploy/python/keypoint_postprocess.py (+66 −12)
- deploy/python/keypoint_visualize.py (+12 −5)
- deploy/python/topdown_unite_utils.py (+5 −0)
- deploy/python/utils.py (+5 −0)
- ppdet/data/source/keypoint_coco.py (+13 −5)
- ppdet/engine/trainer.py (+10 −1)
- ppdet/metrics/keypoint_metrics.py (+166 −1)
configs/keypoint/README.md

````diff
@@ -13,7 +13,7 @@
 #### Model Zoo
+COCO Dataset
 | Model | Input Size | Width | AP(coco val) | Model Download | Config |
 | :---- | :--------: | :---: | :----------: | :------------: | :----: |
 | HigherHRNet | 512 | 32 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) |
@@ -25,6 +25,12 @@
 | HRNet+DarkPose | 384x288 | 32 | 78.3 | [dark_hrnet_w32_384x288.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/dark_hrnet_w32_384x288.pdparams) | [config](./hrnet/dark_hrnet_w32_384x288.yml) |
 Note: the Top-Down models' AP is tested on ground-truth bounding boxes.
+MPII Dataset
+| Model | Input Size | Width | PCKh(Mean) | PCKh(Mean@0.1) | Model Download | Config |
+| :---- | :--------: | :---: | :--------: | :------------: | :------------: | :----: |
+| HRNet | 256x256 | 32 | 90.6 | 38.5 | [hrnet_w32_256x256_mpii.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x256_mpii.pdparams) | [config](./hrnet/hrnet_w32_256x256_mpii.yml) |
 ## Quick Start
 ### 1. Environment Setup
@@ -74,9 +80,9 @@ python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w
 # Deployment inference
 # Standalone keypoint top-down/bottom-up inference; in this mode the top-down model only supports single-person image crops.
-python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --use_gpu=True --threshold=0.5
+python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --device=gpu --threshold=0.5
-python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --use_gpu=True --threshold=0.5
+python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --device=gpu --threshold=0.5
 # Joint deployment of a keypoint top-down model with a detector (joint inference supports the top-down mode only)
-python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4 --use_gpu=True
+python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4 --device=gpu
 ```
````
configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml

```diff
@@ -118,9 +118,6 @@ EvalReader:
   sample_transforms:
     - TopDownAffine:
         trainsize: *trainsize
-    - ToHeatmapsTopDown_DARK:
-        hmsize: *hmsize
-        sigma: 2
   batch_transforms:
     - NormalizeImage:
         mean: *global_mean
```
configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml

```diff
@@ -118,9 +118,6 @@ EvalReader:
   sample_transforms:
     - TopDownAffine:
         trainsize: *trainsize
-    - ToHeatmapsTopDown_DARK:
-        hmsize: *hmsize
-        sigma: 2
   batch_transforms:
     - NormalizeImage:
         mean: *global_mean
```
configs/keypoint/hrnet/hrnet_w32_256x192.yml

```diff
@@ -118,9 +118,6 @@ EvalReader:
   sample_transforms:
     - TopDownAffine:
         trainsize: *trainsize
-    - ToHeatmapsTopDown:
-        hmsize: *hmsize
-        sigma: 2
   batch_transforms:
     - NormalizeImage:
         mean: *global_mean
```
configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml (new file, mode 100644, +130 −0)

```yaml
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w32_256x256_mpii/model_final
epoch: 210
num_joints: &num_joints 16
pixel_std: &pixel_std 200
metric: KeyPointTopDownMPIIEval
num_classes: 1
train_height: &train_height 256
train_width: &train_width 256
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [64, 64]
flip_perm: &flip_perm [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]

#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams

TopDownHRNet:
  backbone: HRNet
  post_process: HRNetPostProcess
  flip_perm: *flip_perm
  num_joints: *num_joints
  width: &width 32
  loss: KeyPointMSELoss

HRNet:
  width: *width
  freeze_at: -1
  freeze_norm: false
  return_idx: [0]

KeyPointMSELoss:
  use_target_weight: true

#####optimizer
LearningRate:
  base_lr: 0.0005
  schedulers:
  - !PiecewiseDecay
    milestones: [170, 200]
    gamma: 0.1
  - !LinearWarmup
    start_factor: 0.001
    steps: 1000

OptimizerBuilder:
  optimizer:
    type: Adam
  regularizer:
    factor: 0.0
    type: L2

#####data
TrainDataset:
  !KeypointTopDownMPIIDataset
    image_dir: images
    anno_path: annotations/mpii_train.json
    dataset_dir: dataset/mpii
    num_joints: *num_joints

EvalDataset:
  !KeypointTopDownMPIIDataset
    image_dir: images
    anno_path: annotations/mpii_val.json
    dataset_dir: dataset/mpii
    num_joints: *num_joints

TestDataset:
  !ImageFolder
    anno_path: dataset/coco/keypoint_imagelist.txt

worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]

TrainReader:
  sample_transforms:
    - RandomFlipHalfBodyTransform:
        scale: 0.5
        rot: 40
        num_joints_half_body: 8
        prob_half_body: 0.3
        pixel_std: *pixel_std
        trainsize: *trainsize
        upper_body_ids: [7, 8, 9, 10, 11, 12, 13, 14, 15]
        flip_pairs: *flip_perm
    - TopDownAffine:
        trainsize: *trainsize
    - ToHeatmapsTopDown:
        hmsize: *hmsize
        sigma: 2
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 64
  shuffle: true
  drop_last: false

EvalReader:
  sample_transforms:
    - TopDownAffine:
        trainsize: *trainsize
  batch_transforms:
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 16

TestReader:
  sample_transforms:
    - Decode: {}
    - TopDownEvalAffine:
        trainsize: *trainsize
    - NormalizeImage:
        mean: *global_mean
        std: *global_std
        is_scale: true
    - Permute: {}
  batch_size: 1
```
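Note: the `flip_perm` pairs above drive flip testing — the model is also run on a horizontally mirrored input, and the mirrored heatmaps must be un-flipped and have their left/right channels swapped before merging. Below is a minimal standalone sketch of that swap, modeled on the `flip_back` method visible in deploy/python/keypoint_postprocess.py; the `FLIP_PERM` constant and driver code here are illustrative, not part of the commit:

```python
import numpy as np

# MPII left/right joint pairs, matching the config's flip_perm anchor.
FLIP_PERM = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]

def flip_back(output_flipped, matched_parts):
    # output_flipped: heatmaps [batch_size, num_joints, height, width]
    # predicted on a horizontally flipped image.
    assert output_flipped.ndim == 4
    output_flipped = output_flipped[:, :, :, ::-1]  # mirror along width
    for left, right in matched_parts:               # swap left/right channels
        tmp = output_flipped[:, left].copy()
        output_flipped[:, left] = output_flipped[:, right]
        output_flipped[:, right] = tmp
    return output_flipped

heatmaps = np.random.rand(2, 16, 64, 64).astype(np.float32)
print(flip_back(heatmaps, FLIP_PERM).shape)  # (2, 16, 64, 64)
```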
configs/keypoint/hrnet/hrnet_w32_384x288.yml

```diff
@@ -119,9 +119,6 @@ EvalReader:
   sample_transforms:
     - TopDownAffine:
         trainsize: *trainsize
-    - ToHeatmapsTopDown:
-        hmsize: *hmsize
-        sigma: 2
   batch_transforms:
     - NormalizeImage:
         mean: *global_mean
```
deploy/python/keypoint_det_unite_infer.py

```diff
@@ -178,7 +178,8 @@ def main():
         trt_opt_shape=FLAGS.trt_opt_shape,
         trt_calib_mode=FLAGS.trt_calib_mode,
         cpu_threads=FLAGS.cpu_threads,
-        enable_mkldnn=FLAGS.enable_mkldnn)
+        enable_mkldnn=FLAGS.enable_mkldnn,
+        use_dark=FLAGS.use_dark)
     # predict from video file or camera video stream
     if FLAGS.video_file is not None or FLAGS.camera_id != -1:
```
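For orientation, joint inference is strictly top-down: the detector proposes person boxes and the keypoint model then runs on each cropped region. A rough, hypothetical sketch of that control flow (the predictor names and box layout are illustrative; the actual script additionally handles batching, box expansion, and video streams):

```python
def unite_predict(detector, keypoint_model, image):
    """Top-down joint inference: detect people first, then per-crop pose."""
    boxes = detector.predict(image)  # assumed [x1, y1, x2, y2, score] per person
    poses = []
    for x1, y1, x2, y2, score in boxes:
        crop = image[int(y1):int(y2), int(x1):int(x2)]
        poses.append(keypoint_model.predict(crop))
    return poses
```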
deploy/python/keypoint_infer.py

```diff
@@ -63,7 +63,8 @@ class KeyPoint_Detector(object):
                  trt_opt_shape=640,
                  trt_calib_mode=False,
                  cpu_threads=1,
-                 enable_mkldnn=False):
+                 enable_mkldnn=False,
+                 use_dark=True):
         self.pred_config = pred_config
         self.predictor, self.config = load_predictor(
             model_dir,
@@ -79,6 +80,7 @@ class KeyPoint_Detector(object):
             enable_mkldnn=enable_mkldnn)
         self.det_times = Timer()
         self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
+        self.use_dark = use_dark

     def preprocess(self, im):
         preprocess_ops = []
@@ -109,7 +111,7 @@ class KeyPoint_Detector(object):
             imshape = inputs['im_shape'][:, ::-1]
             center = np.round(imshape / 2.)
             scale = imshape / 200.
-            keypoint_postprocess = HRNetPostProcess()
+            keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark)
             results['keypoint'] = keypoint_postprocess(np_boxes, center, scale)
             return results
         else:
@@ -390,7 +392,8 @@ def main():
         trt_opt_shape=FLAGS.trt_opt_shape,
         trt_calib_mode=FLAGS.trt_calib_mode,
         cpu_threads=FLAGS.cpu_threads,
-        enable_mkldnn=FLAGS.enable_mkldnn)
+        enable_mkldnn=FLAGS.enable_mkldnn,
+        use_dark=FLAGS.use_dark)
     # predict from video file or camera video stream
     if FLAGS.video_file is not None or FLAGS.camera_id != -1:
```
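In the standalone branch above, the affine parameters come from the input shape alone: the crop center is the image center, and the scale is the image size expressed in units of `pixel_std = 200` (the same convention the MPII config anchors). A tiny sketch of those two lines in isolation:

```python
import numpy as np

im_shape = np.array([[256.0, 192.0]])  # one input crop, (h, w)
imshape = im_shape[:, ::-1]            # reorder to (w, h)
center = np.round(imshape / 2.)        # [[ 96., 128.]] -- crop center in pixels
scale = imshape / 200.                 # [[0.96, 1.28]] -- size / pixel_std
print(center, scale)
```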
deploy/python/keypoint_postprocess.py

```diff
@@ -14,6 +14,7 @@
 from scipy.optimize import linear_sum_assignment
 from collections import abc, defaultdict
+import cv2
 import numpy as np
 import math
 import paddle
@@ -193,6 +194,9 @@ def warp_affine_joints(joints, mat):

 class HRNetPostProcess(object):
+    def __init__(self, use_dark=True):
+        self.use_dark = use_dark
+
     def flip_back(self, output_flipped, matched_parts):
         assert output_flipped.ndim == 4,\
             'output_flipped should be [batch_size, num_joints, height, width]'
@@ -242,7 +246,54 @@ class HRNetPostProcess(object):
         return preds, maxvals

-    def get_final_preds(self, heatmaps, center, scale):
+    def gaussian_blur(self, heatmap, kernel):
+        border = (kernel - 1) // 2
+        batch_size = heatmap.shape[0]
+        num_joints = heatmap.shape[1]
+        height = heatmap.shape[2]
+        width = heatmap.shape[3]
+        for i in range(batch_size):
+            for j in range(num_joints):
+                origin_max = np.max(heatmap[i, j])
+                dr = np.zeros((height + 2 * border, width + 2 * border))
+                dr[border:-border, border:-border] = heatmap[i, j].copy()
+                dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
+                heatmap[i, j] = dr[border:-border, border:-border].copy()
+                heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
+        return heatmap
+
+    def dark_parse(self, hm, coord):
+        heatmap_height = hm.shape[0]
+        heatmap_width = hm.shape[1]
+        px = int(coord[0])
+        py = int(coord[1])
+        if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
+            dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
+            dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
+            dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
+            dxy = 0.25 * (hm[py + 1][px + 1] - hm[py - 1][px + 1] -
+                          hm[py + 1][px - 1] + hm[py - 1][px - 1])
+            dyy = 0.25 * (hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px])
+            derivative = np.matrix([[dx], [dy]])
+            hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
+            if dxx * dyy - dxy**2 != 0:
+                hessianinv = hessian.I
+                offset = -hessianinv * derivative
+                offset = np.squeeze(np.array(offset.T), axis=0)
+                coord += offset
+        return coord
+
+    def dark_postprocess(self, hm, coords, kernelsize):
+        hm = self.gaussian_blur(hm, kernelsize)
+        hm = np.maximum(hm, 1e-10)
+        hm = np.log(hm)
+        for n in range(coords.shape[0]):
+            for p in range(coords.shape[1]):
+                coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
+        return coords
+
+    def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
         """the highest heatvalue location with a quarter offset in the
         direction from the highest response to the second highest response.
@@ -261,17 +312,20 @@ class HRNetPostProcess(object):
         heatmap_height = heatmaps.shape[2]
         heatmap_width = heatmaps.shape[3]

-        for n in range(coords.shape[0]):
-            for p in range(coords.shape[1]):
-                hm = heatmaps[n][p]
-                px = int(math.floor(coords[n][p][0] + 0.5))
-                py = int(math.floor(coords[n][p][1] + 0.5))
-                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
-                    diff = np.array([
-                        hm[py][px + 1] - hm[py][px - 1],
-                        hm[py + 1][px] - hm[py - 1][px]
-                    ])
-                    coords[n][p] += np.sign(diff) * .25
+        if self.use_dark:
+            coords = self.dark_postprocess(heatmaps, coords, kernelsize)
+        else:
+            for n in range(coords.shape[0]):
+                for p in range(coords.shape[1]):
+                    hm = heatmaps[n][p]
+                    px = int(math.floor(coords[n][p][0] + 0.5))
+                    py = int(math.floor(coords[n][p][1] + 0.5))
+                    if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
+                        diff = np.array([
+                            hm[py][px + 1] - hm[py][px - 1],
+                            hm[py + 1][px] - hm[py - 1][px]
+                        ])
+                        coords[n][p] += np.sign(diff) * .25
         preds = coords.copy()

         # Transform back
```
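The new methods implement DARK post-processing: smooth the heatmap, take its log (so a Gaussian peak becomes a quadratic surface), and move the integer argmax by the Newton step `-H⁻¹∇` from a second-order Taylor expansion. A self-contained check of that refinement on a synthetic heatmap; `dark_refine` below re-implements the arithmetic of `dark_parse` purely for illustration and is not part of the commit:

```python
import numpy as np

def dark_refine(hm, coord):
    # Same derivative/Hessian arithmetic as dark_parse, on a log-heatmap.
    px, py = int(coord[0]), int(coord[1])
    h, w = hm.shape
    if 1 < px < w - 2 and 1 < py < h - 2:
        dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
        dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
        dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
        dxy = 0.25 * (hm[py + 1][px + 1] - hm[py - 1][px + 1] -
                      hm[py + 1][px - 1] + hm[py - 1][px - 1])
        dyy = 0.25 * (hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
        if dxx * dyy - dxy**2 != 0:
            offset = -np.linalg.inv([[dxx, dxy], [dxy, dyy]]) @ np.array([dx, dy])
            coord = coord + offset  # sub-pixel step: -H^{-1} * gradient
    return coord

# Synthetic Gaussian peak at a non-integer location (31.6, 24.3).
ys, xs = np.mgrid[0:64, 0:64]
hm = np.exp(-((xs - 31.6)**2 + (ys - 24.3)**2) / (2 * 2.0**2))
hm = np.log(np.maximum(hm, 1e-10))  # DARK operates on the log-heatmap
y0, x0 = np.unravel_index(np.argmax(hm), hm.shape)
print(dark_refine(hm, np.array([x0, y0], dtype=np.float64)))  # ~[31.6, 24.3]
```

On a quadratic log-surface the step is exact, which is why DARK recovers sub-pixel peaks that the quarter-pixel `np.sign(diff) * .25` nudge in the non-DARK branch can only approximate.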
deploy/python/keypoint_visualize.py

```diff
@@ -34,9 +34,16 @@ def draw_pose(imgfile,
             'for example: `pip install matplotlib`.')
         raise e
-    EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
-             (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15),
-             (14, 16), (11, 12)]
+    skeletons, scores = results['keypoint']
+    kpt_nums = len(skeletons[0])
+    if kpt_nums == 17:  #plot coco keypoint
+        EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
+                 (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15),
+                 (14, 16), (11, 12)]
+    else:  #plot mpii keypoint
+        EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
+                 (8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12), (8, 13)]
     NUM_EDGES = len(EDGES)
     colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
@@ -46,7 +53,7 @@ def draw_pose(imgfile,
     plt.figure()
     img = cv2.imread(imgfile) if type(imgfile) == str else imgfile
-    skeletons, scores = results['keypoint']
     color_set = results['colors'] if 'colors' in results else None
     if 'bbox' in results:
@@ -58,7 +65,7 @@ def draw_pose(imgfile,
         cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
     canvas = img.copy()
-    for i in range(17):
+    for i in range(kpt_nums):
         for j in range(len(skeletons)):
             if skeletons[j][i, 2] < visual_thread:
                 continue
```
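The visualizer now picks the skeleton from the keypoint count instead of hard-coding COCO. A standalone restatement of that selection (hypothetical helper, same edge lists as the diff):

```python
# 17 keypoints -> COCO topology; otherwise MPII (16 joints).
COCO_EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
              (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15),
              (14, 16), (11, 12)]
MPII_EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
              (8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12), (8, 13)]

def select_edges(kpt_nums):
    return COCO_EDGES if kpt_nums == 17 else MPII_EDGES

print(len(select_edges(17)), len(select_edges(16)))  # 17 and 15 edges
```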
deploy/python/topdown_unite_utils.py

```diff
@@ -103,5 +103,10 @@ def argsparser():
         default=False,
         help="If the model is produced by TRT offline quantitative "
         "calibration, trt_calib_mode need to set True.")
+    parser.add_argument(
+        '--use_dark',
+        type=bool,
+        default=True,
+        help='whether to use darkpose to get better keypoint position predict ')
     return parser
```
deploy/python/utils.py

```diff
@@ -108,6 +108,11 @@ def argsparser():
         '--save_results',
         action='store_true',
         help='Save tracking results (txt).')
+    parser.add_argument(
+        '--use_dark',
+        type=bool,
+        default=True,
+        help='whether to use darkpose to get better keypoint position predict ')
     return parser
```
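One caveat worth noting about the `--use_dark` flag added in both deploy/python/topdown_unite_utils.py and deploy/python/utils.py: argparse's `type=bool` does not parse strings, so passing `--use_dark=False` on the command line still yields `True` (any non-empty string is truthy in Python). A quick demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_dark', type=bool, default=True)

print(parser.parse_args(['--use_dark=False']).use_dark)  # True, not False!
print(parser.parse_args([]).use_dark)                    # True (the default)
print(parser.parse_args(['--use_dark=']).use_dark)       # False (empty string)
```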
ppdet/data/source/keypoint_coco.py

```diff
@@ -25,7 +25,8 @@ from ppdet.core.workspace import register, serializable
 @serializable
 class KeypointBottomUpBaseDataset(DetDataset):
-    """Base class for bottom-up datasets.
+    """Base class for bottom-up datasets. Adapted from
+    https://github.com/open-mmlab/mmpose

     All datasets should subclass it.
     All subclasses should overwrite:
@@ -86,7 +87,8 @@ class KeypointBottomUpBaseDataset(DetDataset):
 @register
 @serializable
 class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
-    """COCO dataset for bottom-up pose estimation.
+    """COCO dataset for bottom-up pose estimation. Adapted from
+    https://github.com/open-mmlab/mmpose

     The dataset loads raw features and apply specified transforms
     to return a dict containing the image tensors and other information.
@@ -253,7 +255,8 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
 @register
 @serializable
 class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
-    """CrowdPose dataset for bottom-up pose estimation.
+    """CrowdPose dataset for bottom-up pose estimation. Adapted from
+    https://github.com/open-mmlab/mmpose

     The dataset loads raw features and apply specified transforms
     to return a dict containing the image tensors and other information.
@@ -374,7 +377,9 @@ class KeypointTopDownBaseDataset(DetDataset):
 @register
 @serializable
 class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
-    """COCO dataset for top-down pose estimation.
+    """COCO dataset for top-down pose estimation. Adapted from
+    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+    Copyright (c) Microsoft, under the MIT License.

     The dataset loads raw features and apply specified transforms
     to return a dict containing the image tensors and other information.
@@ -567,7 +572,9 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
 @register
 @serializable
 class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
-    """MPII dataset for topdown pose estimation.
+    """MPII dataset for topdown pose estimation. Adapted from
+    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+    Copyright (c) Microsoft, under the MIT License.

     The dataset loads raw features and apply specified transforms
     to return a dict containing the image tensors and other information.
@@ -653,4 +660,5 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
                 'joints': joints,
                 'joints_vis': joints_vis
             })
+        print("number length: {}".format(len(gt_db)))
         self.db = gt_db
```
ppdet/engine/trainer.py

```diff
@@ -35,7 +35,7 @@ from ppdet.optimizer import ModelEMA
 from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
 from ppdet.utils.visualizer import visualize_results, save_result
-from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval
+from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
 from ppdet.metrics import RBoxMetric
 from ppdet.data.source.category import get_categories
 import ppdet.utils.stats as stats
@@ -234,6 +234,15 @@ class Trainer(object):
                     len(eval_dataset), self.cfg.num_joints,
                     self.cfg.save_dir)
             ]
+        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
+            eval_dataset = self.cfg['EvalDataset']
+            eval_dataset.check_or_download_dataset()
+            anno_file = eval_dataset.get_anno()
+            self._metrics = [
+                KeyPointTopDownMPIIEval(anno_file,
+                                        len(eval_dataset), self.cfg.num_joints,
+                                        self.cfg.save_dir)
+            ]
         else:
             logger.warn("Metric not support for metric type {}".format(
                 self.cfg.metric))
```
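This branch plugs KeyPointTopDownMPIIEval into the trainer's metric protocol: one `update()` per eval batch, then `accumulate()` and `log()` once at the end, matching the interface of the class added below in ppdet/metrics/keypoint_metrics.py. A minimal sketch of that driving loop with a dummy metric (hypothetical harness, not the Trainer's actual code):

```python
class DummyMetric:
    """Same interface the Trainer drives: reset/update/accumulate/log."""
    def reset(self):
        self.n = 0
    def update(self, inputs, outputs):
        self.n += 1  # a real metric collects per-batch predictions here
    def accumulate(self):
        self.result = {'batches_seen': self.n}
    def log(self):
        print(self.result)

metric = DummyMetric()
metric.reset()
for _ in range(3):  # one update per eval batch
    metric.update(inputs=None, outputs=None)
metric.accumulate()
metric.log()  # {'batches_seen': 3}
```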
ppdet/metrics/keypoint_metrics.py

```diff
@@ -21,11 +21,18 @@ import numpy as np
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
 from ..modeling.keypoint_utils import oks_nms
+from scipy.io import loadmat, savemat

-__all__ = ['KeyPointTopDownCOCOEval']
+__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']


 class KeyPointTopDownCOCOEval(object):
+    '''
+    Adapted from
+    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+    Copyright (c) Microsoft, under the MIT License.
+    '''
     def __init__(self,
                  anno_file,
                  num_samples,
```

The second hunk (`@@ -200,3 +207,161 @@`) appends the new KeyPointTopDownMPIIEval class after `KeyPointTopDownCOCOEval.get_results`:

```python
class KeyPointTopDownMPIIEval(object):
    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 oks_thre=0.9):
        super(KeyPointTopDownMPIIEval, self).__init__()
        self.ann_file = anno_file
        self.reset()

    def reset(self):
        self.results = []
        self.eval_results = {}
        self.idx = 0

    def update(self, inputs, outputs):
        kpts, _ = outputs['keypoint'][0]

        num_images = inputs['image'].shape[0]
        results = {}
        results['preds'] = kpts[:, :, 0:3]
        results['boxes'] = np.zeros((num_images, 6))
        results['boxes'][:, 0:2] = inputs['center'].numpy()[:, 0:2]
        results['boxes'][:, 2:4] = inputs['scale'].numpy()[:, 0:2]
        results['boxes'][:, 4] = np.prod(inputs['scale'].numpy() * 200, 1)
        results['boxes'][:, 5] = np.squeeze(inputs['score'].numpy())
        results['image_path'] = inputs['image_file']

        self.results.append(results)

    def accumulate(self):
        self.eval_results = self.evaluate(self.results)

    def log(self):
        for item, value in self.eval_results.items():
            print("{} : {}".format(item, value))

    def get_results(self):
        return self.eval_results

    def evaluate(self, outputs, savepath=None):
        """Evaluate PCKh for MPII dataset. Adapted from
        https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
        Copyright (c) Microsoft, under the MIT License.

        Args:
            outputs(list(preds, boxes)):

                * preds (np.ndarray[N,K,3]): The first two dimensions are
                  coordinates, score is the third dimension of the array.
                * boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],
                  scale[1], area, score]

        Returns:
            dict: PCKh for each joint
        """
        kpts = []
        for output in outputs:
            preds = output['preds']
            batch_size = preds.shape[0]
            for i in range(batch_size):
                kpts.append({'keypoints': preds[i]})

        preds = np.stack([kpt['keypoints'] for kpt in kpts])

        # convert 0-based index to 1-based index,
        # and get the first two dimensions.
        preds = preds[..., :2] + 1.0

        if savepath is not None:
            pred_file = os.path.join(savepath, 'pred.mat')
            savemat(pred_file, mdict={'preds': preds})

        SC_BIAS = 0.6
        threshold = 0.5

        gt_file = os.path.join(os.path.dirname(self.ann_file), 'mpii_gt_val.mat')
        gt_dict = loadmat(gt_file)
        dataset_joints = gt_dict['dataset_joints']
        jnt_missing = gt_dict['jnt_missing']
        pos_gt_src = gt_dict['pos_gt_src']
        headboxes_src = gt_dict['headboxes_src']

        pos_pred_src = np.transpose(preds, [1, 2, 0])

        head = np.where(dataset_joints == 'head')[1][0]
        lsho = np.where(dataset_joints == 'lsho')[1][0]
        lelb = np.where(dataset_joints == 'lelb')[1][0]
        lwri = np.where(dataset_joints == 'lwri')[1][0]
        lhip = np.where(dataset_joints == 'lhip')[1][0]
        lkne = np.where(dataset_joints == 'lkne')[1][0]
        lank = np.where(dataset_joints == 'lank')[1][0]

        rsho = np.where(dataset_joints == 'rsho')[1][0]
        relb = np.where(dataset_joints == 'relb')[1][0]
        rwri = np.where(dataset_joints == 'rwri')[1][0]
        rkne = np.where(dataset_joints == 'rkne')[1][0]
        rank = np.where(dataset_joints == 'rank')[1][0]
        rhip = np.where(dataset_joints == 'rhip')[1][0]

        jnt_visible = 1 - jnt_missing
        uv_error = pos_pred_src - pos_gt_src
        uv_err = np.linalg.norm(uv_error, axis=1)
        headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
        headsizes = np.linalg.norm(headsizes, axis=0)
        headsizes *= SC_BIAS
        scale = headsizes * np.ones((len(uv_err), 1), dtype=np.float32)
        scaled_uv_err = uv_err / scale
        scaled_uv_err = scaled_uv_err * jnt_visible
        jnt_count = np.sum(jnt_visible, axis=1)
        less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
        PCKh = 100. * np.sum(less_than_threshold, axis=1) / jnt_count

        # save
        rng = np.arange(0, 0.5 + 0.01, 0.01)
        pckAll = np.zeros((len(rng), 16), dtype=np.float32)

        for r, threshold in enumerate(rng):
            less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
            pckAll[r, :] = 100. * np.sum(less_than_threshold, axis=1) / jnt_count

        PCKh = np.ma.array(PCKh, mask=False)
        PCKh.mask[6:8] = True

        jnt_count = np.ma.array(jnt_count, mask=False)
        jnt_count.mask[6:8] = True
        jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)

        name_value = [  #noqa
            ('Head', PCKh[head]),
            ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
            ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
            ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
            ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
            ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
            ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
            ('PCKh', np.sum(PCKh * jnt_ratio)),
            ('PCKh@0.1', np.sum(pckAll[11, :] * jnt_ratio))
        ]
        name_value = OrderedDict(name_value)

        return name_value

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts
```
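To make the PCKh rule above concrete: a joint counts as correct when its pixel error is at most `threshold` times `SC_BIAS * head_size` (0.6 of the head-box diagonal); PCKh@0.1 applies the same rule with threshold 0.1 instead of 0.5. A small numeric check with made-up values:

```python
import numpy as np

SC_BIAS, threshold = 0.6, 0.5
head_size = 40.0                       # head-box diagonal in pixels (made up)
errs = np.array([10.0, 13.0, 25.0])    # pixel errors for three joints
scaled = errs / (SC_BIAS * head_size)  # [0.417, 0.542, 1.042]
pckh = 100. * np.mean(scaled <= threshold)
print(pckh)                            # 33.33...: only the 10 px error passes
```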