Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
e0c027f3
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e0c027f3
编写于
5月 23, 2022
作者:
C
chenjian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add pp_tinypose
上级
7f9274d9
变更
14
展开全部
隐藏空白更改
内联
并排
Showing
14 changed file
with
3191 addition
and
0 deletion
+3191
-0
modules/image/keypoint_detection/pp-tinypose/README.md
modules/image/keypoint_detection/pp-tinypose/README.md
+136
-0
modules/image/keypoint_detection/pp-tinypose/__init__.py
modules/image/keypoint_detection/pp-tinypose/__init__.py
+5
-0
modules/image/keypoint_detection/pp-tinypose/benchmark_utils.py
...s/image/keypoint_detection/pp-tinypose/benchmark_utils.py
+262
-0
modules/image/keypoint_detection/pp-tinypose/det_keypoint_unite_infer.py
...eypoint_detection/pp-tinypose/det_keypoint_unite_infer.py
+230
-0
modules/image/keypoint_detection/pp-tinypose/det_keypoint_unite_utils.py
...eypoint_detection/pp-tinypose/det_keypoint_unite_utils.py
+86
-0
modules/image/keypoint_detection/pp-tinypose/infer.py
modules/image/keypoint_detection/pp-tinypose/infer.py
+694
-0
modules/image/keypoint_detection/pp-tinypose/keypoint_infer.py
...es/image/keypoint_detection/pp-tinypose/keypoint_infer.py
+381
-0
modules/image/keypoint_detection/pp-tinypose/keypoint_postprocess.py
...ge/keypoint_detection/pp-tinypose/keypoint_postprocess.py
+192
-0
modules/image/keypoint_detection/pp-tinypose/keypoint_preprocess.py
...age/keypoint_detection/pp-tinypose/keypoint_preprocess.py
+232
-0
modules/image/keypoint_detection/pp-tinypose/logger.py
modules/image/keypoint_detection/pp-tinypose/logger.py
+68
-0
modules/image/keypoint_detection/pp-tinypose/module.py
modules/image/keypoint_detection/pp-tinypose/module.py
+148
-0
modules/image/keypoint_detection/pp-tinypose/preprocess.py
modules/image/keypoint_detection/pp-tinypose/preprocess.py
+332
-0
modules/image/keypoint_detection/pp-tinypose/utils.py
modules/image/keypoint_detection/pp-tinypose/utils.py
+217
-0
modules/image/keypoint_detection/pp-tinypose/visualize.py
modules/image/keypoint_detection/pp-tinypose/visualize.py
+208
-0
未找到文件。
modules/image/keypoint_detection/pp-tinypose/README.md
0 → 100644
浏览文件 @
e0c027f3
# pp-tinypose
|模型名称|pp-tinypose|
| :--- | :---: |
|类别|图像-关键点检测|
|网络|PicoDet + HRNet|
|数据集|COCO + AI Challenger|
|是否支持Fine-tuning|否|
|模型大小|125M|
|最新更新日期|2022-05-20|
|数据指标|-|
## 一、模型基本信息
-
### 应用效果展示
-
样例结果示例:
<p
align=
"center"
>
<img
src=
"https://user-images.githubusercontent.com/22424850/169768593-9fcf729a-458e-4bb1-bb3c-b005ff7bcec2.jpg"
hspace=
'10'
/>
<br
/>
输入图像
<br
/>
<img
src=
"https://user-images.githubusercontent.com/22424850/169768604-d23a1851-c18b-4f9f-a8ab-2c3f3080e393.jpg"
hspace=
'10'
/>
<br
/>
输出图像
-
### 模型介绍
-
PP-TinyPose是PaddleDetecion针对移动端设备优化的实时关键点检测模型,可流畅地在移动端设备上执行多人姿态估计任务。借助PaddleDetecion自研的优秀轻量级检测模型PicoDet以及轻量级姿态估计任务骨干网络HRNet, 结合多种策略有效平衡了模型的速度和精度表现。
-
更多详情参考:
[
PP-TinyPose
](
https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/keypoint/tiny_pose
)
。
## 二、安装
-
### 1、环境依赖
-
paddlepaddle >= 2.2
-
paddlehub >= 2.2 |
[
如何安装paddlehub
](
../../../../docs/docs_ch/get_start/installation.rst
)
-
### 2、安装
-
```shell
$ hub install pp-tinypose
```
-
如您安装时遇到问题,可参考:
[
零基础windows安装
](
../../../../docs/docs_ch/get_start/windows_quickstart.md
)
|
[
零基础Linux安装
](
../../../../docs/docs_ch/get_start/linux_quickstart.md
)
|
[
零基础MacOS安装
](
../../../../docs/docs_ch/get_start/mac_quickstart.md
)
## 三、模型API预测
-
### 1、命令行预测
-
```shell
$ hub run pp-tinypose --input_path "/PATH/TO/IMAGE" --visualization True --use_gpu
```
-
通过命令行方式实现关键点检测模型的调用,更多请见
[
PaddleHub命令行指令
](
../../../../docs/docs_ch/tutorial/cmd_usage.rst
)
-
### 2、代码示例
-
```python
import paddlehub as hub
import cv2
model = hub.Module(name="pp-tinypose")
result = model.predict('/PATH/TO/IMAGE', save_path='pp_tinypose_output', visualization=True, use_gpu=True)
```
-
### 3、API
-
```python
def predict(self, img: Union[str, np.ndarray], save_path: str = "pp_tinypose_output", visualization: bool = True, use_gpu = False)
```
- 预测API,识别输入图片中的所有人肢体关键点。
- **参数**
- img (numpy.ndarray|str): 图片数据,使用图片路径或者输入numpy.ndarray,BGR格式;
- save_path (str): 图片保存路径, 默认为'pp_tinypose_output';
- visualization (bool): 是否将识别结果保存为图片文件;
- use_gpu: 是否使用gpu;
- **返回**
- res (list\[dict\]): 识别结果的列表,列表元素依然为列表,存的内容为[图像名称,检测框,关键点]。
## 四、服务部署
-
PaddleHub Serving 可以部署一个关键点检测的在线服务。
-
### 第一步:启动PaddleHub Serving
-
运行启动命令:
-
```shell
$ hub serving start -m pp-tinypose
```
-
这样就完成了一个关键点检测的服务化API的部署,默认端口号为8866。
-
**NOTE:**
如使用GPU预测,则需要在启动服务之前,请设置CUDA
\_
VISIBLE
\_
DEVICES环境变量,否则不用设置。
-
### 第二步:发送预测请求
-
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
-
```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/pp-tinypose"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
```
## 五、更新历史
*
1.0.0
初始发布
-
```shell
$ hub install pp-tinypose==1.0.0
```
modules/image/keypoint_detection/pp-tinypose/__init__.py
0 → 100644
浏览文件 @
e0c027f3
import
os
import
sys
CUR_DIR
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
CUR_DIR
)
modules/image/keypoint_detection/pp-tinypose/benchmark_utils.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
os
from
pathlib
import
Path
import
paddle
import
paddle.inference
as
paddle_infer
CUR_DIR
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
LOG_PATH_ROOT
=
f
"
{
CUR_DIR
}
/../../output"
class
PaddleInferBenchmark
(
object
):
def
__init__
(
self
,
config
,
model_info
:
dict
=
{},
data_info
:
dict
=
{},
perf_info
:
dict
=
{},
resource_info
:
dict
=
{},
**
kwargs
):
"""
Construct PaddleInferBenchmark Class to format logs.
args:
config(paddle.inference.Config): paddle inference config
model_info(dict): basic model info
{'model_name': 'resnet50'
'precision': 'fp32'}
data_info(dict): input data info
{'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info(dict): performance result
{'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info(dict):
cpu and gpu resources
{'cpu_rss': 100
'gpu_rss': 100
'gpu_util': 60}
"""
# PaddleInferBenchmark Log Version
self
.
log_version
=
"1.0.3"
# Paddle Version
self
.
paddle_version
=
paddle
.
__version__
self
.
paddle_commit
=
paddle
.
__git_commit__
paddle_infer_info
=
paddle_infer
.
get_version
()
self
.
paddle_branch
=
paddle_infer_info
.
strip
().
split
(
': '
)[
-
1
]
# model info
self
.
model_info
=
model_info
# data info
self
.
data_info
=
data_info
# perf info
self
.
perf_info
=
perf_info
try
:
# required value
self
.
model_name
=
model_info
[
'model_name'
]
self
.
precision
=
model_info
[
'precision'
]
self
.
batch_size
=
data_info
[
'batch_size'
]
self
.
shape
=
data_info
[
'shape'
]
self
.
data_num
=
data_info
[
'data_num'
]
self
.
inference_time_s
=
round
(
perf_info
[
'inference_time_s'
],
4
)
except
:
self
.
print_help
()
raise
ValueError
(
"Set argument wrong, please check input argument and its type"
)
self
.
preprocess_time_s
=
perf_info
.
get
(
'preprocess_time_s'
,
0
)
self
.
postprocess_time_s
=
perf_info
.
get
(
'postprocess_time_s'
,
0
)
self
.
with_tracker
=
True
if
'tracking_time_s'
in
perf_info
else
False
self
.
tracking_time_s
=
perf_info
.
get
(
'tracking_time_s'
,
0
)
self
.
total_time_s
=
perf_info
.
get
(
'total_time_s'
,
0
)
self
.
inference_time_s_90
=
perf_info
.
get
(
"inference_time_s_90"
,
""
)
self
.
inference_time_s_99
=
perf_info
.
get
(
"inference_time_s_99"
,
""
)
self
.
succ_rate
=
perf_info
.
get
(
"succ_rate"
,
""
)
self
.
qps
=
perf_info
.
get
(
"qps"
,
""
)
# conf info
self
.
config_status
=
self
.
parse_config
(
config
)
# mem info
if
isinstance
(
resource_info
,
dict
):
self
.
cpu_rss_mb
=
int
(
resource_info
.
get
(
'cpu_rss_mb'
,
0
))
self
.
cpu_vms_mb
=
int
(
resource_info
.
get
(
'cpu_vms_mb'
,
0
))
self
.
cpu_shared_mb
=
int
(
resource_info
.
get
(
'cpu_shared_mb'
,
0
))
self
.
cpu_dirty_mb
=
int
(
resource_info
.
get
(
'cpu_dirty_mb'
,
0
))
self
.
cpu_util
=
round
(
resource_info
.
get
(
'cpu_util'
,
0
),
2
)
self
.
gpu_rss_mb
=
int
(
resource_info
.
get
(
'gpu_rss_mb'
,
0
))
self
.
gpu_util
=
round
(
resource_info
.
get
(
'gpu_util'
,
0
),
2
)
self
.
gpu_mem_util
=
round
(
resource_info
.
get
(
'gpu_mem_util'
,
0
),
2
)
else
:
self
.
cpu_rss_mb
=
0
self
.
cpu_vms_mb
=
0
self
.
cpu_shared_mb
=
0
self
.
cpu_dirty_mb
=
0
self
.
cpu_util
=
0
self
.
gpu_rss_mb
=
0
self
.
gpu_util
=
0
self
.
gpu_mem_util
=
0
# init benchmark logger
self
.
benchmark_logger
()
def
benchmark_logger
(
self
):
"""
benchmark logger
"""
# remove other logging handler
for
handler
in
logging
.
root
.
handlers
[:]:
logging
.
root
.
removeHandler
(
handler
)
# Init logger
FORMAT
=
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
log_output
=
f
"
{
LOG_PATH_ROOT
}
/
{
self
.
model_name
}
.log"
Path
(
f
"
{
LOG_PATH_ROOT
}
"
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
,
handlers
=
[
logging
.
FileHandler
(
filename
=
log_output
,
mode
=
'w'
),
logging
.
StreamHandler
(),
])
self
.
logger
=
logging
.
getLogger
(
__name__
)
self
.
logger
.
info
(
f
"Paddle Inference benchmark log will be saved to
{
log_output
}
"
)
def
parse_config
(
self
,
config
)
->
dict
:
"""
parse paddle predictor config
args:
config(paddle.inference.Config): paddle inference config
return:
config_status(dict): dict style config info
"""
if
isinstance
(
config
,
paddle_infer
.
Config
):
config_status
=
{}
config_status
[
'runtime_device'
]
=
"gpu"
if
config
.
use_gpu
()
else
"cpu"
config_status
[
'ir_optim'
]
=
config
.
ir_optim
()
config_status
[
'enable_tensorrt'
]
=
config
.
tensorrt_engine_enabled
()
config_status
[
'precision'
]
=
self
.
precision
config_status
[
'enable_mkldnn'
]
=
config
.
mkldnn_enabled
()
config_status
[
'cpu_math_library_num_threads'
]
=
config
.
cpu_math_library_num_threads
()
elif
isinstance
(
config
,
dict
):
config_status
[
'runtime_device'
]
=
config
.
get
(
'runtime_device'
,
""
)
config_status
[
'ir_optim'
]
=
config
.
get
(
'ir_optim'
,
""
)
config_status
[
'enable_tensorrt'
]
=
config
.
get
(
'enable_tensorrt'
,
""
)
config_status
[
'precision'
]
=
config
.
get
(
'precision'
,
""
)
config_status
[
'enable_mkldnn'
]
=
config
.
get
(
'enable_mkldnn'
,
""
)
config_status
[
'cpu_math_library_num_threads'
]
=
config
.
get
(
'cpu_math_library_num_threads'
,
""
)
else
:
self
.
print_help
()
raise
ValueError
(
"Set argument config wrong, please check input argument and its type"
)
return
config_status
def
report
(
self
,
identifier
=
None
):
"""
print log report
args:
identifier(string): identify log
"""
if
identifier
:
identifier
=
f
"[
{
identifier
}
]"
else
:
identifier
=
""
self
.
logger
.
info
(
"
\n
"
)
self
.
logger
.
info
(
"---------------------- Paddle info ----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
paddle_version:
{
self
.
paddle_version
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
paddle_commit:
{
self
.
paddle_commit
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
paddle_branch:
{
self
.
paddle_branch
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
log_api_version:
{
self
.
log_version
}
"
)
self
.
logger
.
info
(
"----------------------- Conf info -----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
runtime_device:
{
self
.
config_status
[
'runtime_device'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
ir_optim:
{
self
.
config_status
[
'ir_optim'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
enable_memory_optim:
{
True
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
enable_tensorrt:
{
self
.
config_status
[
'enable_tensorrt'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
enable_mkldnn:
{
self
.
config_status
[
'enable_mkldnn'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
cpu_math_library_num_threads:
{
self
.
config_status
[
'cpu_math_library_num_threads'
]
}
"
)
self
.
logger
.
info
(
"----------------------- Model info ----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
model_name:
{
self
.
model_name
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
precision:
{
self
.
precision
}
"
)
self
.
logger
.
info
(
"----------------------- Data info -----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
batch_size:
{
self
.
batch_size
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
input_shape:
{
self
.
shape
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
data_num:
{
self
.
data_num
}
"
)
self
.
logger
.
info
(
"----------------------- Perf info -----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
cpu_rss(MB):
{
self
.
cpu_rss_mb
}
, cpu_vms:
{
self
.
cpu_vms_mb
}
, cpu_shared_mb:
{
self
.
cpu_shared_mb
}
, cpu_dirty_mb:
{
self
.
cpu_dirty_mb
}
, cpu_util:
{
self
.
cpu_util
}
%"
)
self
.
logger
.
info
(
f
"
{
identifier
}
gpu_rss(MB):
{
self
.
gpu_rss_mb
}
, gpu_util:
{
self
.
gpu_util
}
%, gpu_mem_util:
{
self
.
gpu_mem_util
}
%"
)
self
.
logger
.
info
(
f
"
{
identifier
}
total time spent(s):
{
self
.
total_time_s
}
"
)
if
self
.
with_tracker
:
self
.
logger
.
info
(
f
"
{
identifier
}
preprocess_time(ms):
{
round
(
self
.
preprocess_time_s
*
1000
,
1
)
}
, "
f
"inference_time(ms):
{
round
(
self
.
inference_time_s
*
1000
,
1
)
}
, "
f
"postprocess_time(ms):
{
round
(
self
.
postprocess_time_s
*
1000
,
1
)
}
, "
f
"tracking_time(ms):
{
round
(
self
.
tracking_time_s
*
1000
,
1
)
}
"
)
else
:
self
.
logger
.
info
(
f
"
{
identifier
}
preprocess_time(ms):
{
round
(
self
.
preprocess_time_s
*
1000
,
1
)
}
, "
f
"inference_time(ms):
{
round
(
self
.
inference_time_s
*
1000
,
1
)
}
, "
f
"postprocess_time(ms):
{
round
(
self
.
postprocess_time_s
*
1000
,
1
)
}
"
)
if
self
.
inference_time_s_90
:
self
.
looger
.
info
(
f
"
{
identifier
}
90%_cost:
{
self
.
inference_time_s_90
}
, 99%_cost:
{
self
.
inference_time_s_99
}
, succ_rate:
{
self
.
succ_rate
}
"
)
if
self
.
qps
:
self
.
logger
.
info
(
f
"
{
identifier
}
QPS:
{
self
.
qps
}
"
)
def
print_help
(
self
):
"""
print function help
"""
print
(
"""Usage:
==== Print inference benchmark logs. ====
config = paddle.inference.Config()
model_info = {'model_name': 'resnet50'
'precision': 'fp32'}
data_info = {'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info = {'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info = {'cpu_rss_mb': 100
'gpu_rss_mb': 100
'gpu_util': 60}
log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
log('Test')
"""
)
def
__call__
(
self
,
identifier
=
None
):
"""
__call__
args:
identifier(string): identify log
"""
self
.
report
(
identifier
)
modules/image/keypoint_detection/pp-tinypose/det_keypoint_unite_infer.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
math
import
os
import
cv2
import
numpy
as
np
import
paddle
import
yaml
from
benchmark_utils
import
PaddleInferBenchmark
from
det_keypoint_unite_utils
import
argsparser
from
infer
import
bench_log
from
infer
import
Detector
from
infer
import
get_test_images
from
infer
import
PredictConfig
from
infer
import
print_arguments
from
keypoint_infer
import
KeyPointDetector
from
keypoint_infer
import
PredictConfig_KeyPoint
from
keypoint_postprocess
import
translate_to_ori_images
from
preprocess
import
decode_image
from
utils
import
get_current_memory_mb
from
visualize
import
visualize_pose
KEYPOINT_SUPPORT_MODELS
=
{
'HigherHRNet'
:
'keypoint_bottomup'
,
'HRNet'
:
'keypoint_topdown'
}
def
predict_with_given_det
(
image
,
det_res
,
keypoint_detector
,
keypoint_batch_size
,
run_benchmark
):
rec_images
,
records
,
det_rects
=
keypoint_detector
.
get_person_from_rect
(
image
,
det_res
)
keypoint_vector
=
[]
score_vector
=
[]
rect_vector
=
det_rects
keypoint_results
=
keypoint_detector
.
predict_image
(
rec_images
,
run_benchmark
,
repeats
=
10
,
visual
=
False
)
keypoint_vector
,
score_vector
=
translate_to_ori_images
(
keypoint_results
,
np
.
array
(
records
))
keypoint_res
=
{}
keypoint_res
[
'keypoint'
]
=
[
keypoint_vector
.
tolist
(),
score_vector
.
tolist
()]
if
len
(
keypoint_vector
)
>
0
else
[[],
[]]
keypoint_res
[
'bbox'
]
=
rect_vector
return
keypoint_res
def
topdown_unite_predict
(
detector
,
topdown_keypoint_detector
,
image_list
,
keypoint_batch_size
=
1
,
save_res
=
False
):
det_timer
=
detector
.
get_timer
()
store_res
=
[]
for
i
,
img_file
in
enumerate
(
image_list
):
# Decode image in advance in det + pose prediction
det_timer
.
preprocess_time_s
.
start
()
image
,
_
=
decode_image
(
img_file
,
{})
det_timer
.
preprocess_time_s
.
end
()
if
FLAGS
.
run_benchmark
:
results
=
detector
.
predict_image
([
image
],
run_benchmark
=
True
,
repeats
=
10
)
cm
,
gm
,
gu
=
get_current_memory_mb
()
detector
.
cpu_mem
+=
cm
detector
.
gpu_mem
+=
gm
detector
.
gpu_util
+=
gu
else
:
results
=
detector
.
predict_image
([
image
],
visual
=
False
)
results
=
detector
.
filter_box
(
results
,
FLAGS
.
det_threshold
)
if
results
[
'boxes_num'
]
>
0
:
keypoint_res
=
predict_with_given_det
(
image
,
results
,
topdown_keypoint_detector
,
keypoint_batch_size
,
FLAGS
.
run_benchmark
)
if
save_res
:
save_name
=
img_file
if
isinstance
(
img_file
,
str
)
else
i
store_res
.
append
(
[
save_name
,
keypoint_res
[
'bbox'
],
[
keypoint_res
[
'keypoint'
][
0
],
keypoint_res
[
'keypoint'
][
1
]]])
else
:
results
[
"keypoint"
]
=
[[],
[]]
keypoint_res
=
results
if
FLAGS
.
run_benchmark
:
cm
,
gm
,
gu
=
get_current_memory_mb
()
topdown_keypoint_detector
.
cpu_mem
+=
cm
topdown_keypoint_detector
.
gpu_mem
+=
gm
topdown_keypoint_detector
.
gpu_util
+=
gu
else
:
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
visualize_pose
(
img_file
,
keypoint_res
,
visual_thresh
=
FLAGS
.
keypoint_threshold
,
save_dir
=
FLAGS
.
output_dir
)
if
save_res
:
"""
1) store_res: a list of image_data
2) image_data: [imageid, rects, [keypoints, scores]]
3) rects: list of rect [xmin, ymin, xmax, ymax]
4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
5) scores: mean of all joint conf
"""
with
open
(
"det_keypoint_unite_image_results.json"
,
'w'
)
as
wf
:
json
.
dump
(
store_res
,
wf
,
indent
=
4
)
def
topdown_unite_predict_video
(
detector
,
topdown_keypoint_detector
,
camera_id
,
keypoint_batch_size
=
1
,
save_res
=
False
):
video_name
=
'output.mp4'
if
camera_id
!=
-
1
:
capture
=
cv2
.
VideoCapture
(
camera_id
)
else
:
capture
=
cv2
.
VideoCapture
(
FLAGS
.
video_file
)
video_name
=
os
.
path
.
split
(
FLAGS
.
video_file
)[
-
1
]
# Get Video info : resolution, fps, frame count
width
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
height
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
fps
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FPS
))
frame_count
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
print
(
"fps: %d, frame_count: %d"
%
(
fps
,
frame_count
))
if
not
os
.
path
.
exists
(
FLAGS
.
output_dir
):
os
.
makedirs
(
FLAGS
.
output_dir
)
out_path
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
video_name
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
index
=
0
store_res
=
[]
while
(
1
):
ret
,
frame
=
capture
.
read
()
if
not
ret
:
break
index
+=
1
print
(
'detect frame: %d'
%
(
index
))
frame2
=
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_BGR2RGB
)
results
=
detector
.
predict_image
([
frame2
],
visual
=
False
)
results
=
detector
.
filter_box
(
results
,
FLAGS
.
det_threshold
)
if
results
[
'boxes_num'
]
==
0
:
writer
.
write
(
frame
)
continue
keypoint_res
=
predict_with_given_det
(
frame2
,
results
,
topdown_keypoint_detector
,
keypoint_batch_size
,
FLAGS
.
run_benchmark
)
im
=
visualize_pose
(
frame
,
keypoint_res
,
visual_thresh
=
FLAGS
.
keypoint_threshold
,
returnimg
=
True
)
if
save_res
:
store_res
.
append
([
index
,
keypoint_res
[
'bbox'
],
[
keypoint_res
[
'keypoint'
][
0
],
keypoint_res
[
'keypoint'
][
1
]]])
writer
.
write
(
im
)
if
camera_id
!=
-
1
:
cv2
.
imshow
(
'Mask Detection'
,
im
)
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
break
writer
.
release
()
print
(
'output_video saved to: {}'
.
format
(
out_path
))
if
save_res
:
"""
1) store_res: a list of frame_data
2) frame_data: [frameid, rects, [keypoints, scores]]
3) rects: list of rect [xmin, ymin, xmax, ymax]
4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
5) scores: mean of all joint conf
"""
with
open
(
"det_keypoint_unite_video_results.json"
,
'w'
)
as
wf
:
json
.
dump
(
store_res
,
wf
,
indent
=
4
)
def
main
():
deploy_file
=
os
.
path
.
join
(
FLAGS
.
det_model_dir
,
'infer_cfg.yml'
)
with
open
(
deploy_file
)
as
f
:
yml_conf
=
yaml
.
safe_load
(
f
)
arch
=
yml_conf
[
'arch'
]
detector
=
Detector
(
FLAGS
.
det_model_dir
,
device
=
FLAGS
.
device
,
run_mode
=
FLAGS
.
run_mode
,
trt_min_shape
=
FLAGS
.
trt_min_shape
,
trt_max_shape
=
FLAGS
.
trt_max_shape
,
trt_opt_shape
=
FLAGS
.
trt_opt_shape
,
trt_calib_mode
=
FLAGS
.
trt_calib_mode
,
cpu_threads
=
FLAGS
.
cpu_threads
,
enable_mkldnn
=
FLAGS
.
enable_mkldnn
,
threshold
=
FLAGS
.
det_threshold
)
topdown_keypoint_detector
=
KeyPointDetector
(
FLAGS
.
keypoint_model_dir
,
device
=
FLAGS
.
device
,
run_mode
=
FLAGS
.
run_mode
,
batch_size
=
FLAGS
.
keypoint_batch_size
,
trt_min_shape
=
FLAGS
.
trt_min_shape
,
trt_max_shape
=
FLAGS
.
trt_max_shape
,
trt_opt_shape
=
FLAGS
.
trt_opt_shape
,
trt_calib_mode
=
FLAGS
.
trt_calib_mode
,
cpu_threads
=
FLAGS
.
cpu_threads
,
enable_mkldnn
=
FLAGS
.
enable_mkldnn
,
use_dark
=
FLAGS
.
use_dark
)
keypoint_arch
=
topdown_keypoint_detector
.
pred_config
.
arch
assert
KEYPOINT_SUPPORT_MODELS
[
keypoint_arch
]
==
'keypoint_topdown'
,
'Detection-Keypoint unite inference only supports topdown models.'
# predict from video file or camera video stream
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
topdown_unite_predict_video
(
detector
,
topdown_keypoint_detector
,
FLAGS
.
camera_id
,
FLAGS
.
keypoint_batch_size
,
FLAGS
.
save_res
)
else
:
# predict from image
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
topdown_unite_predict
(
detector
,
topdown_keypoint_detector
,
img_list
,
FLAGS
.
keypoint_batch_size
,
FLAGS
.
save_res
)
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
topdown_keypoint_detector
.
det_times
.
info
(
average
=
True
)
else
:
mode
=
FLAGS
.
run_mode
det_model_dir
=
FLAGS
.
det_model_dir
det_model_info
=
{
'model_name'
:
det_model_dir
.
strip
(
'/'
).
split
(
'/'
)[
-
1
],
'precision'
:
mode
.
split
(
'_'
)[
-
1
]}
bench_log
(
detector
,
img_list
,
det_model_info
,
name
=
'Det'
)
keypoint_model_dir
=
FLAGS
.
keypoint_model_dir
keypoint_model_info
=
{
'model_name'
:
keypoint_model_dir
.
strip
(
'/'
).
split
(
'/'
)[
-
1
],
'precision'
:
mode
.
split
(
'_'
)[
-
1
]
}
bench_log
(
topdown_keypoint_detector
,
img_list
,
keypoint_model_info
,
FLAGS
.
keypoint_batch_size
,
'KeyPoint'
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
FLAGS
.
device
=
FLAGS
.
device
.
upper
()
assert
FLAGS
.
device
in
[
'CPU'
,
'GPU'
,
'XPU'
],
"device should be CPU, GPU or XPU"
main
()
modules/image/keypoint_detection/pp-tinypose/det_keypoint_unite_utils.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
,
default
=
None
,
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
),
required
=
True
)
parser
.
add_argument
(
"--keypoint_model_dir"
,
type
=
str
,
default
=
None
,
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
),
required
=
True
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
,
default
=
None
,
help
=
"Dir of image file, `image_file` has a higher priority."
)
parser
.
add_argument
(
"--keypoint_batch_size"
,
type
=
int
,
default
=
8
,
help
=
(
"batch_size for keypoint inference. In detection-keypoint unit"
"inference, the batch size in detection is 1. Then collate det "
"result in batch for keypoint inference."
))
parser
.
add_argument
(
"--video_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser
.
add_argument
(
"--camera_id"
,
type
=
int
,
default
=-
1
,
help
=
"device id of camera to predict."
)
parser
.
add_argument
(
"--det_threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Threshold of score."
)
parser
.
add_argument
(
"--keypoint_threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Threshold of score."
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Directory of output visualization files."
)
parser
.
add_argument
(
"--run_mode"
,
type
=
str
,
default
=
'paddle'
,
help
=
"mode of running(paddle/trt_fp32/trt_fp16/trt_int8)"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'cpu'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
)
parser
.
add_argument
(
"--run_benchmark"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether to predict a image_file repeatedly for benchmark"
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn with CPU."
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
1
,
help
=
"Num of threads with CPU."
)
parser
.
add_argument
(
"--trt_min_shape"
,
type
=
int
,
default
=
1
,
help
=
"min_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_max_shape"
,
type
=
int
,
default
=
1280
,
help
=
"max_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_opt_shape"
,
type
=
int
,
default
=
640
,
help
=
"opt_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_calib_mode"
,
type
=
bool
,
default
=
False
,
help
=
"If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True."
)
parser
.
add_argument
(
'--use_dark'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'whether to use darkpose to get better keypoint position predict '
)
parser
.
add_argument
(
'--save_res'
,
type
=
bool
,
default
=
False
,
help
=
(
"whether to save predict results to json file"
"1) store_res: a list of image_data"
"2) image_data: [imageid, rects, [keypoints, scores]]"
"3) rects: list of rect [xmin, ymin, xmax, ymax]"
"4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list"
"5) scores: mean of all joint conf"
))
return
parser
modules/image/keypoint_detection/pp-tinypose/infer.py
0 → 100644
浏览文件 @
e0c027f3
此差异已折叠。
点击以展开。
modules/image/keypoint_detection/pp-tinypose/keypoint_infer.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
glob
import
math
import
os
import
sys
import
time
from
functools
import
reduce
import
cv2
import
numpy
as
np
import
paddle
import
yaml
from
PIL
import
Image
# add deploy path of PadleDetection to sys.path
parent_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
*
([
'..'
])))
sys
.
path
.
insert
(
0
,
parent_path
)
from
preprocess
import
preprocess
,
NormalizeImage
,
Permute
from
keypoint_preprocess
import
EvalAffine
,
TopDownEvalAffine
,
expand_crop
from
keypoint_postprocess
import
HRNetPostProcess
from
visualize
import
visualize_pose
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
utils
import
argsparser
,
Timer
,
get_current_memory_mb
from
benchmark_utils
import
PaddleInferBenchmark
from
infer
import
Detector
,
get_test_images
,
print_arguments
# Global dictionary
KEYPOINT_SUPPORT_MODELS
=
{
'HigherHRNet'
:
'keypoint_bottomup'
,
'HRNet'
:
'keypoint_topdown'
}
class
KeyPointDetector
(
Detector
):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
use_dark(bool): whether to use postprocess in DarkPose
"""
def
__init__
(
self
,
model_dir
,
device
=
'CPU'
,
run_mode
=
'paddle'
,
batch_size
=
1
,
trt_min_shape
=
1
,
trt_max_shape
=
1280
,
trt_opt_shape
=
640
,
trt_calib_mode
=
False
,
cpu_threads
=
1
,
enable_mkldnn
=
False
,
output_dir
=
'output'
,
threshold
=
0.5
,
use_dark
=
True
):
super
(
KeyPointDetector
,
self
).
__init__
(
model_dir
=
model_dir
,
device
=
device
,
run_mode
=
run_mode
,
batch_size
=
batch_size
,
trt_min_shape
=
trt_min_shape
,
trt_max_shape
=
trt_max_shape
,
trt_opt_shape
=
trt_opt_shape
,
trt_calib_mode
=
trt_calib_mode
,
cpu_threads
=
cpu_threads
,
enable_mkldnn
=
enable_mkldnn
,
output_dir
=
output_dir
,
threshold
=
threshold
,
)
self
.
use_dark
=
use_dark
def
set_config
(
self
,
model_dir
):
return
PredictConfig_KeyPoint
(
model_dir
)
def
get_person_from_rect
(
self
,
image
,
results
):
# crop the person result from image
self
.
det_times
.
preprocess_time_s
.
start
()
valid_rects
=
results
[
'boxes'
]
rect_images
=
[]
new_rects
=
[]
org_rects
=
[]
for
rect
in
valid_rects
:
rect_image
,
new_rect
,
org_rect
=
expand_crop
(
image
,
rect
)
if
rect_image
is
None
or
rect_image
.
size
==
0
:
continue
rect_images
.
append
(
rect_image
)
new_rects
.
append
(
new_rect
)
org_rects
.
append
(
org_rect
)
self
.
det_times
.
preprocess_time_s
.
end
()
return
rect_images
,
new_rects
,
org_rects
def
postprocess
(
self
,
inputs
,
result
):
np_heatmap
=
result
[
'heatmap'
]
np_masks
=
result
[
'masks'
]
# postprocess output of predictor
if
KEYPOINT_SUPPORT_MODELS
[
self
.
pred_config
.
arch
]
==
'keypoint_bottomup'
:
results
=
{}
h
,
w
=
inputs
[
'im_shape'
][
0
]
preds
=
[
np_heatmap
]
if
np_masks
is
not
None
:
preds
+=
np_masks
preds
+=
[
h
,
w
]
keypoint_postprocess
=
HRNetPostProcess
()
kpts
,
scores
=
keypoint_postprocess
(
*
preds
)
results
[
'keypoint'
]
=
kpts
results
[
'score'
]
=
scores
return
results
elif
KEYPOINT_SUPPORT_MODELS
[
self
.
pred_config
.
arch
]
==
'keypoint_topdown'
:
results
=
{}
imshape
=
inputs
[
'im_shape'
][:,
::
-
1
]
center
=
np
.
round
(
imshape
/
2.
)
scale
=
imshape
/
200.
keypoint_postprocess
=
HRNetPostProcess
(
use_dark
=
self
.
use_dark
)
kpts
,
scores
=
keypoint_postprocess
(
np_heatmap
,
center
,
scale
)
results
[
'keypoint'
]
=
kpts
results
[
'score'
]
=
scores
return
results
else
:
raise
ValueError
(
"Unsupported arch: {}, expect {}"
.
format
(
self
.
pred_config
.
arch
,
KEYPOINT_SUPPORT_MODELS
))
def
predict
(
self
,
repeats
=
1
):
'''
Args:
repeats (int): repeat number for prediction
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape: [N, im_h, im_w]
'''
# model prediction
np_heatmap
,
np_masks
=
None
,
None
for
i
in
range
(
repeats
):
self
.
predictor
.
run
()
output_names
=
self
.
predictor
.
get_output_names
()
heatmap_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
0
])
np_heatmap
=
heatmap_tensor
.
copy_to_cpu
()
if
self
.
pred_config
.
tagmap
:
masks_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
])
heat_k
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
])
inds_k
=
self
.
predictor
.
get_output_handle
(
output_names
[
3
])
np_masks
=
[
masks_tensor
.
copy_to_cpu
(),
heat_k
.
copy_to_cpu
(),
inds_k
.
copy_to_cpu
()]
result
=
dict
(
heatmap
=
np_heatmap
,
masks
=
np_masks
)
return
result
def
predict_image
(
self
,
image_list
,
run_benchmark
=
False
,
repeats
=
1
,
visual
=
True
):
results
=
[]
batch_loop_cnt
=
math
.
ceil
(
float
(
len
(
image_list
))
/
self
.
batch_size
)
for
i
in
range
(
batch_loop_cnt
):
start_index
=
i
*
self
.
batch_size
end_index
=
min
((
i
+
1
)
*
self
.
batch_size
,
len
(
image_list
))
batch_image_list
=
image_list
[
start_index
:
end_index
]
if
run_benchmark
:
# preprocess
inputs
=
self
.
preprocess
(
batch_image_list
)
# warmup
self
.
det_times
.
preprocess_time_s
.
start
()
inputs
=
self
.
preprocess
(
batch_image_list
)
self
.
det_times
.
preprocess_time_s
.
end
()
# model prediction
result_warmup
=
self
.
predict
(
repeats
=
repeats
)
# warmup
self
.
det_times
.
inference_time_s
.
start
()
result
=
self
.
predict
(
repeats
=
repeats
)
self
.
det_times
.
inference_time_s
.
end
(
repeats
=
repeats
)
# postprocess
result_warmup
=
self
.
postprocess
(
inputs
,
result
)
# warmup
self
.
det_times
.
postprocess_time_s
.
start
()
result
=
self
.
postprocess
(
inputs
,
result
)
self
.
det_times
.
postprocess_time_s
.
end
()
self
.
det_times
.
img_num
+=
len
(
batch_image_list
)
cm
,
gm
,
gu
=
get_current_memory_mb
()
self
.
cpu_mem
+=
cm
self
.
gpu_mem
+=
gm
self
.
gpu_util
+=
gu
else
:
# preprocess
self
.
det_times
.
preprocess_time_s
.
start
()
inputs
=
self
.
preprocess
(
batch_image_list
)
self
.
det_times
.
preprocess_time_s
.
end
()
# model prediction
self
.
det_times
.
inference_time_s
.
start
()
result
=
self
.
predict
()
self
.
det_times
.
inference_time_s
.
end
()
# postprocess
self
.
det_times
.
postprocess_time_s
.
start
()
result
=
self
.
postprocess
(
inputs
,
result
)
self
.
det_times
.
postprocess_time_s
.
end
()
self
.
det_times
.
img_num
+=
len
(
batch_image_list
)
if
visual
:
if
not
os
.
path
.
exists
(
self
.
output_dir
):
os
.
makedirs
(
self
.
output_dir
)
visualize
(
batch_image_list
,
result
,
visual_thresh
=
self
.
threshold
,
save_dir
=
self
.
output_dir
)
results
.
append
(
result
)
if
visual
:
print
(
'Test iter {}'
.
format
(
i
))
results
=
self
.
merge_batch_result
(
results
)
return
results
def
predict_video
(
self
,
video_file
,
camera_id
):
video_name
=
'output.mp4'
if
camera_id
!=
-
1
:
capture
=
cv2
.
VideoCapture
(
camera_id
)
else
:
capture
=
cv2
.
VideoCapture
(
video_file
)
video_name
=
os
.
path
.
split
(
video_file
)[
-
1
]
# Get Video info : resolution, fps, frame count
width
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
))
height
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
))
fps
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FPS
))
frame_count
=
int
(
capture
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
print
(
"fps: %d, frame_count: %d"
%
(
fps
,
frame_count
))
if
not
os
.
path
.
exists
(
self
.
output_dir
):
os
.
makedirs
(
self
.
output_dir
)
out_path
=
os
.
path
.
join
(
self
.
output_dir
,
video_name
)
fourcc
=
cv2
.
VideoWriter_fourcc
(
*
'mp4v'
)
writer
=
cv2
.
VideoWriter
(
out_path
,
fourcc
,
fps
,
(
width
,
height
))
index
=
1
while
(
1
):
ret
,
frame
=
capture
.
read
()
if
not
ret
:
break
print
(
'detect frame: %d'
%
(
index
))
index
+=
1
results
=
self
.
predict_image
([
frame
[:,
:,
::
-
1
]],
visual
=
False
)
im_results
=
{}
im_results
[
'keypoint'
]
=
[
results
[
'keypoint'
],
results
[
'score'
]]
im
=
visualize_pose
(
frame
,
im_results
,
visual_thresh
=
self
.
threshold
,
returnimg
=
True
)
writer
.
write
(
im
)
if
camera_id
!=
-
1
:
cv2
.
imshow
(
'Mask Detection'
,
im
)
if
cv2
.
waitKey
(
1
)
&
0xFF
==
ord
(
'q'
):
break
writer
.
release
()
def
create_inputs
(
imgs
,
im_info
):
"""generate input for different model type
Args:
imgs (list(numpy)): list of image (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs
=
{}
inputs
[
'image'
]
=
np
.
stack
(
imgs
,
axis
=
0
).
astype
(
'float32'
)
im_shape
=
[]
for
e
in
im_info
:
im_shape
.
append
(
np
.
array
((
e
[
'im_shape'
])).
astype
(
'float32'
))
inputs
[
'im_shape'
]
=
np
.
stack
(
im_shape
,
axis
=
0
)
return
inputs
class
PredictConfig_KeyPoint
():
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of model.yml
"""
def
__init__
(
self
,
model_dir
):
# parsing Yaml config for Preprocess
deploy_file
=
os
.
path
.
join
(
model_dir
,
'infer_cfg.yml'
)
with
open
(
deploy_file
)
as
f
:
yml_conf
=
yaml
.
safe_load
(
f
)
self
.
check_model
(
yml_conf
)
self
.
arch
=
yml_conf
[
'arch'
]
self
.
archcls
=
KEYPOINT_SUPPORT_MODELS
[
yml_conf
[
'arch'
]]
self
.
preprocess_infos
=
yml_conf
[
'Preprocess'
]
self
.
min_subgraph_size
=
yml_conf
[
'min_subgraph_size'
]
self
.
labels
=
yml_conf
[
'label_list'
]
self
.
tagmap
=
False
self
.
use_dynamic_shape
=
yml_conf
[
'use_dynamic_shape'
]
if
'keypoint_bottomup'
==
self
.
archcls
:
self
.
tagmap
=
True
self
.
print_config
()
def
check_model
(
self
,
yml_conf
):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for
support_model
in
KEYPOINT_SUPPORT_MODELS
:
if
support_model
in
yml_conf
[
'arch'
]:
return
True
raise
ValueError
(
"Unsupported arch: {}, expect {}"
.
format
(
yml_conf
[
'arch'
],
KEYPOINT_SUPPORT_MODELS
))
def
print_config
(
self
):
print
(
'----------- Model Configuration -----------'
)
print
(
'%s: %s'
%
(
'Model Arch'
,
self
.
arch
))
print
(
'%s: '
%
(
'Transform Order'
))
for
op_info
in
self
.
preprocess_infos
:
print
(
'--%s: %s'
%
(
'transform op'
,
op_info
[
'type'
]))
print
(
'--------------------------------------------'
)
def
visualize
(
image_list
,
results
,
visual_thresh
=
0.6
,
save_dir
=
'output'
):
im_results
=
{}
for
i
,
image_file
in
enumerate
(
image_list
):
skeletons
=
results
[
'keypoint'
]
scores
=
results
[
'score'
]
skeleton
=
skeletons
[
i
:
i
+
1
]
score
=
scores
[
i
:
i
+
1
]
im_results
[
'keypoint'
]
=
[
skeleton
,
score
]
visualize_pose
(
image_file
,
im_results
,
visual_thresh
=
visual_thresh
,
save_dir
=
save_dir
)
def
main
():
detector
=
KeyPointDetector
(
FLAGS
.
model_dir
,
device
=
FLAGS
.
device
,
run_mode
=
FLAGS
.
run_mode
,
batch_size
=
FLAGS
.
batch_size
,
trt_min_shape
=
FLAGS
.
trt_min_shape
,
trt_max_shape
=
FLAGS
.
trt_max_shape
,
trt_opt_shape
=
FLAGS
.
trt_opt_shape
,
trt_calib_mode
=
FLAGS
.
trt_calib_mode
,
cpu_threads
=
FLAGS
.
cpu_threads
,
enable_mkldnn
=
FLAGS
.
enable_mkldnn
,
threshold
=
FLAGS
.
threshold
,
output_dir
=
FLAGS
.
output_dir
,
use_dark
=
FLAGS
.
use_dark
)
# predict from video file or camera video stream
if
FLAGS
.
video_file
is
not
None
or
FLAGS
.
camera_id
!=
-
1
:
detector
.
predict_video
(
FLAGS
.
video_file
,
FLAGS
.
camera_id
)
else
:
# predict from image
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
detector
.
predict_image
(
img_list
,
FLAGS
.
run_benchmark
,
repeats
=
10
)
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
else
:
mems
=
{
'cpu_rss_mb'
:
detector
.
cpu_mem
/
len
(
img_list
),
'gpu_rss_mb'
:
detector
.
gpu_mem
/
len
(
img_list
),
'gpu_util'
:
detector
.
gpu_util
*
100
/
len
(
img_list
)
}
perf_info
=
detector
.
det_times
.
report
(
average
=
True
)
model_dir
=
FLAGS
.
model_dir
mode
=
FLAGS
.
run_mode
model_info
=
{
'model_name'
:
model_dir
.
strip
(
'/'
).
split
(
'/'
)[
-
1
],
'precision'
:
mode
.
split
(
'_'
)[
-
1
]}
data_info
=
{
'batch_size'
:
1
,
'shape'
:
"dynamic_shape"
,
'data_num'
:
perf_info
[
'img_num'
]}
det_log
=
PaddleInferBenchmark
(
detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
det_log
(
'KeyPoint'
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
FLAGS
.
device
=
FLAGS
.
device
.
upper
()
assert
FLAGS
.
device
in
[
'CPU'
,
'GPU'
,
'XPU'
],
"device should be CPU, GPU or XPU"
assert
not
FLAGS
.
use_gpu
,
"use_gpu has been deprecated, please use --device"
main
()
modules/image/keypoint_detection/pp-tinypose/keypoint_postprocess.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
from
collections
import
abc
from
collections
import
defaultdict
import
cv2
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
from
keypoint_preprocess
import
get_affine_mat_kernel
from
keypoint_preprocess
import
get_affine_transform
from
scipy.optimize
import
linear_sum_assignment
class
HRNetPostProcess
(
object
):
def
__init__
(
self
,
use_dark
=
True
):
self
.
use_dark
=
use_dark
def
flip_back
(
self
,
output_flipped
,
matched_parts
):
assert
output_flipped
.
ndim
==
4
,
\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped
=
output_flipped
[:,
:,
:,
::
-
1
]
for
pair
in
matched_parts
:
tmp
=
output_flipped
[:,
pair
[
0
],
:,
:].
copy
()
output_flipped
[:,
pair
[
0
],
:,
:]
=
output_flipped
[:,
pair
[
1
],
:,
:]
output_flipped
[:,
pair
[
1
],
:,
:]
=
tmp
return
output_flipped
def
get_max_preds
(
self
,
heatmaps
):
"""get predictions from score maps
Args:
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
"""
assert
isinstance
(
heatmaps
,
np
.
ndarray
),
'heatmaps should be numpy.ndarray'
assert
heatmaps
.
ndim
==
4
,
'batch_images should be 4-ndim'
batch_size
=
heatmaps
.
shape
[
0
]
num_joints
=
heatmaps
.
shape
[
1
]
width
=
heatmaps
.
shape
[
3
]
heatmaps_reshaped
=
heatmaps
.
reshape
((
batch_size
,
num_joints
,
-
1
))
idx
=
np
.
argmax
(
heatmaps_reshaped
,
2
)
maxvals
=
np
.
amax
(
heatmaps_reshaped
,
2
)
maxvals
=
maxvals
.
reshape
((
batch_size
,
num_joints
,
1
))
idx
=
idx
.
reshape
((
batch_size
,
num_joints
,
1
))
preds
=
np
.
tile
(
idx
,
(
1
,
1
,
2
)).
astype
(
np
.
float32
)
preds
[:,
:,
0
]
=
(
preds
[:,
:,
0
])
%
width
preds
[:,
:,
1
]
=
np
.
floor
((
preds
[:,
:,
1
])
/
width
)
pred_mask
=
np
.
tile
(
np
.
greater
(
maxvals
,
0.0
),
(
1
,
1
,
2
))
pred_mask
=
pred_mask
.
astype
(
np
.
float32
)
preds
*=
pred_mask
return
preds
,
maxvals
def
gaussian_blur
(
self
,
heatmap
,
kernel
):
border
=
(
kernel
-
1
)
//
2
batch_size
=
heatmap
.
shape
[
0
]
num_joints
=
heatmap
.
shape
[
1
]
height
=
heatmap
.
shape
[
2
]
width
=
heatmap
.
shape
[
3
]
for
i
in
range
(
batch_size
):
for
j
in
range
(
num_joints
):
origin_max
=
np
.
max
(
heatmap
[
i
,
j
])
dr
=
np
.
zeros
((
height
+
2
*
border
,
width
+
2
*
border
))
dr
[
border
:
-
border
,
border
:
-
border
]
=
heatmap
[
i
,
j
].
copy
()
dr
=
cv2
.
GaussianBlur
(
dr
,
(
kernel
,
kernel
),
0
)
heatmap
[
i
,
j
]
=
dr
[
border
:
-
border
,
border
:
-
border
].
copy
()
heatmap
[
i
,
j
]
*=
origin_max
/
np
.
max
(
heatmap
[
i
,
j
])
return
heatmap
def
dark_parse
(
self
,
hm
,
coord
):
heatmap_height
=
hm
.
shape
[
0
]
heatmap_width
=
hm
.
shape
[
1
]
px
=
int
(
coord
[
0
])
py
=
int
(
coord
[
1
])
if
1
<
px
<
heatmap_width
-
2
and
1
<
py
<
heatmap_height
-
2
:
dx
=
0.5
*
(
hm
[
py
][
px
+
1
]
-
hm
[
py
][
px
-
1
])
dy
=
0.5
*
(
hm
[
py
+
1
][
px
]
-
hm
[
py
-
1
][
px
])
dxx
=
0.25
*
(
hm
[
py
][
px
+
2
]
-
2
*
hm
[
py
][
px
]
+
hm
[
py
][
px
-
2
])
dxy
=
0.25
*
(
hm
[
py
+
1
][
px
+
1
]
-
hm
[
py
-
1
][
px
+
1
]
-
hm
[
py
+
1
][
px
-
1
]
\
+
hm
[
py
-
1
][
px
-
1
])
dyy
=
0.25
*
(
hm
[
py
+
2
*
1
][
px
]
-
2
*
hm
[
py
][
px
]
+
hm
[
py
-
2
*
1
][
px
])
derivative
=
np
.
matrix
([[
dx
],
[
dy
]])
hessian
=
np
.
matrix
([[
dxx
,
dxy
],
[
dxy
,
dyy
]])
if
dxx
*
dyy
-
dxy
**
2
!=
0
:
hessianinv
=
hessian
.
I
offset
=
-
hessianinv
*
derivative
offset
=
np
.
squeeze
(
np
.
array
(
offset
.
T
),
axis
=
0
)
coord
+=
offset
return
coord
def
dark_postprocess
(
self
,
hm
,
coords
,
kernelsize
):
"""
refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
"""
hm
=
self
.
gaussian_blur
(
hm
,
kernelsize
)
hm
=
np
.
maximum
(
hm
,
1e-10
)
hm
=
np
.
log
(
hm
)
for
n
in
range
(
coords
.
shape
[
0
]):
for
p
in
range
(
coords
.
shape
[
1
]):
coords
[
n
,
p
]
=
self
.
dark_parse
(
hm
[
n
][
p
],
coords
[
n
][
p
])
return
coords
def
get_final_preds
(
self
,
heatmaps
,
center
,
scale
,
kernelsize
=
3
):
"""the highest heatvalue location with a quarter offset in the
direction from the highest response to the second highest response.
Args:
heatmaps (numpy.ndarray): The predicted heatmaps
center (numpy.ndarray): The boxes center
scale (numpy.ndarray): The scale factor
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
"""
coords
,
maxvals
=
self
.
get_max_preds
(
heatmaps
)
heatmap_height
=
heatmaps
.
shape
[
2
]
heatmap_width
=
heatmaps
.
shape
[
3
]
if
self
.
use_dark
:
coords
=
self
.
dark_postprocess
(
heatmaps
,
coords
,
kernelsize
)
else
:
for
n
in
range
(
coords
.
shape
[
0
]):
for
p
in
range
(
coords
.
shape
[
1
]):
hm
=
heatmaps
[
n
][
p
]
px
=
int
(
math
.
floor
(
coords
[
n
][
p
][
0
]
+
0.5
))
py
=
int
(
math
.
floor
(
coords
[
n
][
p
][
1
]
+
0.5
))
if
1
<
px
<
heatmap_width
-
1
and
1
<
py
<
heatmap_height
-
1
:
diff
=
np
.
array
([
hm
[
py
][
px
+
1
]
-
hm
[
py
][
px
-
1
],
hm
[
py
+
1
][
px
]
-
hm
[
py
-
1
][
px
]])
coords
[
n
][
p
]
+=
np
.
sign
(
diff
)
*
.
25
preds
=
coords
.
copy
()
# Transform back
for
i
in
range
(
coords
.
shape
[
0
]):
preds
[
i
]
=
transform_preds
(
coords
[
i
],
center
[
i
],
scale
[
i
],
[
heatmap_width
,
heatmap_height
])
return
preds
,
maxvals
def
__call__
(
self
,
output
,
center
,
scale
):
preds
,
maxvals
=
self
.
get_final_preds
(
output
,
center
,
scale
)
return
np
.
concatenate
((
preds
,
maxvals
),
axis
=-
1
),
np
.
mean
(
maxvals
,
axis
=
1
)
def
transform_preds
(
coords
,
center
,
scale
,
output_size
):
target_coords
=
np
.
zeros
(
coords
.
shape
)
trans
=
get_affine_transform
(
center
,
scale
*
200
,
0
,
output_size
,
inv
=
1
)
for
p
in
range
(
coords
.
shape
[
0
]):
target_coords
[
p
,
0
:
2
]
=
affine_transform
(
coords
[
p
,
0
:
2
],
trans
)
return
target_coords
def
affine_transform
(
pt
,
t
):
new_pt
=
np
.
array
([
pt
[
0
],
pt
[
1
],
1.
]).
T
new_pt
=
np
.
dot
(
t
,
new_pt
)
return
new_pt
[:
2
]
def
translate_to_ori_images
(
keypoint_result
,
batch_records
):
kpts
=
keypoint_result
[
'keypoint'
]
scores
=
keypoint_result
[
'score'
]
kpts
[...,
0
]
+=
batch_records
[:,
0
:
1
]
kpts
[...,
1
]
+=
batch_records
[:,
1
:
2
]
return
kpts
,
scores
modules/image/keypoint_detection/pp-tinypose/keypoint_preprocess.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
"""
import
cv2
import
numpy
as
np
class
EvalAffine
(
object
):
def
__init__
(
self
,
size
,
stride
=
64
):
super
(
EvalAffine
,
self
).
__init__
()
self
.
size
=
size
self
.
stride
=
stride
def
__call__
(
self
,
image
,
im_info
):
s
=
self
.
size
h
,
w
,
_
=
image
.
shape
trans
,
size_resized
=
get_affine_mat_kernel
(
h
,
w
,
s
,
inv
=
False
)
image_resized
=
cv2
.
warpAffine
(
image
,
trans
,
size_resized
)
return
image_resized
,
im_info
def
get_affine_mat_kernel
(
h
,
w
,
s
,
inv
=
False
):
if
w
<
h
:
w_
=
s
h_
=
int
(
np
.
ceil
((
s
/
w
*
h
)
/
64.
)
*
64
)
scale_w
=
w
scale_h
=
h_
/
w_
*
w
else
:
h_
=
s
w_
=
int
(
np
.
ceil
((
s
/
h
*
w
)
/
64.
)
*
64
)
scale_h
=
h
scale_w
=
w_
/
h_
*
h
center
=
np
.
array
([
np
.
round
(
w
/
2.
),
np
.
round
(
h
/
2.
)])
size_resized
=
(
w_
,
h_
)
trans
=
get_affine_transform
(
center
,
np
.
array
([
scale_w
,
scale_h
]),
0
,
size_resized
,
inv
=
inv
)
return
trans
,
size_resized
def
get_affine_transform
(
center
,
input_size
,
rot
,
output_size
,
shift
=
(
0.
,
0.
),
inv
=
False
):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert
len
(
center
)
==
2
assert
len
(
output_size
)
==
2
assert
len
(
shift
)
==
2
if
not
isinstance
(
input_size
,
(
np
.
ndarray
,
list
)):
input_size
=
np
.
array
([
input_size
,
input_size
],
dtype
=
np
.
float32
)
scale_tmp
=
input_size
shift
=
np
.
array
(
shift
)
src_w
=
scale_tmp
[
0
]
dst_w
=
output_size
[
0
]
dst_h
=
output_size
[
1
]
rot_rad
=
np
.
pi
*
rot
/
180
src_dir
=
rotate_point
([
0.
,
src_w
*
-
0.5
],
rot_rad
)
dst_dir
=
np
.
array
([
0.
,
dst_w
*
-
0.5
])
src
=
np
.
zeros
((
3
,
2
),
dtype
=
np
.
float32
)
src
[
0
,
:]
=
center
+
scale_tmp
*
shift
src
[
1
,
:]
=
center
+
src_dir
+
scale_tmp
*
shift
src
[
2
,
:]
=
_get_3rd_point
(
src
[
0
,
:],
src
[
1
,
:])
dst
=
np
.
zeros
((
3
,
2
),
dtype
=
np
.
float32
)
dst
[
0
,
:]
=
[
dst_w
*
0.5
,
dst_h
*
0.5
]
dst
[
1
,
:]
=
np
.
array
([
dst_w
*
0.5
,
dst_h
*
0.5
])
+
dst_dir
dst
[
2
,
:]
=
_get_3rd_point
(
dst
[
0
,
:],
dst
[
1
,
:])
if
inv
:
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
dst
),
np
.
float32
(
src
))
else
:
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
src
),
np
.
float32
(
dst
))
return
trans
def
get_warp_matrix
(
theta
,
size_input
,
size_dst
,
size_target
):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta
=
np
.
deg2rad
(
theta
)
matrix
=
np
.
zeros
((
2
,
3
),
dtype
=
np
.
float32
)
scale_x
=
size_dst
[
0
]
/
size_target
[
0
]
scale_y
=
size_dst
[
1
]
/
size_target
[
1
]
matrix
[
0
,
0
]
=
np
.
cos
(
theta
)
*
scale_x
matrix
[
0
,
1
]
=
-
np
.
sin
(
theta
)
*
scale_x
matrix
[
0
,
2
]
=
scale_x
*
(
-
0.5
*
size_input
[
0
]
*
np
.
cos
(
theta
)
+
0.5
*
size_input
[
1
]
*
np
.
sin
(
theta
)
+
0.5
*
size_target
[
0
])
matrix
[
1
,
0
]
=
np
.
sin
(
theta
)
*
scale_y
matrix
[
1
,
1
]
=
np
.
cos
(
theta
)
*
scale_y
matrix
[
1
,
2
]
=
scale_y
*
(
-
0.5
*
size_input
[
0
]
*
np
.
sin
(
theta
)
-
0.5
*
size_input
[
1
]
*
np
.
cos
(
theta
)
+
0.5
*
size_target
[
1
])
return
matrix
def
rotate_point
(
pt
,
angle_rad
):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert
len
(
pt
)
==
2
sn
,
cs
=
np
.
sin
(
angle_rad
),
np
.
cos
(
angle_rad
)
new_x
=
pt
[
0
]
*
cs
-
pt
[
1
]
*
sn
new_y
=
pt
[
0
]
*
sn
+
pt
[
1
]
*
cs
rotated_pt
=
[
new_x
,
new_y
]
return
rotated_pt
def
_get_3rd_point
(
a
,
b
):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert
len
(
a
)
==
2
assert
len
(
b
)
==
2
direction
=
a
-
b
third_pt
=
b
+
np
.
array
([
-
direction
[
1
],
direction
[
0
]],
dtype
=
np
.
float32
)
return
third_pt
class
TopDownEvalAffine
(
object
):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
use_udp (bool): whether to use Unbiased Data Processing.
records(dict): the dict contained the image and coords
Returns:
records (dict): contain the image and coords after tranformed
"""
def
__init__
(
self
,
trainsize
,
use_udp
=
False
):
self
.
trainsize
=
trainsize
self
.
use_udp
=
use_udp
def
__call__
(
self
,
image
,
im_info
):
rot
=
0
imshape
=
im_info
[
'im_shape'
][::
-
1
]
center
=
im_info
[
'center'
]
if
'center'
in
im_info
else
imshape
/
2.
scale
=
im_info
[
'scale'
]
if
'scale'
in
im_info
else
imshape
if
self
.
use_udp
:
trans
=
get_warp_matrix
(
rot
,
center
*
2.0
,
[
self
.
trainsize
[
0
]
-
1.0
,
self
.
trainsize
[
1
]
-
1.0
],
scale
)
image
=
cv2
.
warpAffine
(
image
,
trans
,
(
int
(
self
.
trainsize
[
0
]),
int
(
self
.
trainsize
[
1
])),
flags
=
cv2
.
INTER_LINEAR
)
else
:
trans
=
get_affine_transform
(
center
,
scale
,
rot
,
self
.
trainsize
)
image
=
cv2
.
warpAffine
(
image
,
trans
,
(
int
(
self
.
trainsize
[
0
]),
int
(
self
.
trainsize
[
1
])),
flags
=
cv2
.
INTER_LINEAR
)
return
image
,
im_info
def
expand_crop
(
images
,
rect
,
expand_ratio
=
0.3
):
imgh
,
imgw
,
c
=
images
.
shape
label
,
conf
,
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
x
)
for
x
in
rect
.
tolist
()]
if
label
!=
0
:
return
None
,
None
,
None
org_rect
=
[
xmin
,
ymin
,
xmax
,
ymax
]
h_half
=
(
ymax
-
ymin
)
*
(
1
+
expand_ratio
)
/
2.
w_half
=
(
xmax
-
xmin
)
*
(
1
+
expand_ratio
)
/
2.
if
h_half
>
w_half
*
4
/
3
:
w_half
=
h_half
*
0.75
center
=
[(
ymin
+
ymax
)
/
2.
,
(
xmin
+
xmax
)
/
2.
]
ymin
=
max
(
0
,
int
(
center
[
0
]
-
h_half
))
ymax
=
min
(
imgh
-
1
,
int
(
center
[
0
]
+
h_half
))
xmin
=
max
(
0
,
int
(
center
[
1
]
-
w_half
))
xmax
=
min
(
imgw
-
1
,
int
(
center
[
1
]
+
w_half
))
return
images
[
ymin
:
ymax
,
xmin
:
xmax
,
:],
[
xmin
,
ymin
,
xmax
,
ymax
],
org_rect
modules/image/keypoint_detection/pp-tinypose/logger.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
functools
import
logging
import
os
import
sys
import
paddle.distributed
as
dist
__all__
=
[
'setup_logger'
]
logger_initialized
=
[]
def
setup_logger
(
name
=
"ppdet"
,
output
=
None
):
"""
Initialize logger and set its verbosity level to INFO.
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
Returns:
logging.Logger: a logger
"""
logger
=
logging
.
getLogger
(
name
)
if
name
in
logger_initialized
:
return
logger
logger
.
setLevel
(
logging
.
INFO
)
logger
.
propagate
=
False
formatter
=
logging
.
Formatter
(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s"
,
datefmt
=
"%m/%d %H:%M:%S"
)
# stdout logging: master only
local_rank
=
dist
.
get_rank
()
if
local_rank
==
0
:
ch
=
logging
.
StreamHandler
(
stream
=
sys
.
stdout
)
ch
.
setLevel
(
logging
.
DEBUG
)
ch
.
setFormatter
(
formatter
)
logger
.
addHandler
(
ch
)
# file logging: all workers
if
output
is
not
None
:
if
output
.
endswith
(
".txt"
)
or
output
.
endswith
(
".log"
):
filename
=
output
else
:
filename
=
os
.
path
.
join
(
output
,
"log.txt"
)
if
local_rank
>
0
:
filename
=
filename
+
".rank{}"
.
format
(
local_rank
)
os
.
makedirs
(
os
.
path
.
dirname
(
filename
))
fh
=
logging
.
FileHandler
(
filename
,
mode
=
'a'
)
fh
.
setLevel
(
logging
.
DEBUG
)
fh
.
setFormatter
(
logging
.
Formatter
())
logger
.
addHandler
(
fh
)
logger_initialized
.
append
(
name
)
return
logger
modules/image/keypoint_detection/pp-tinypose/module.py
0 → 100644
浏览文件 @
e0c027f3
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
json
import
math
import
os
import
time
from
typing
import
Union
import
cv2
import
numpy
as
np
import
paddle
import
yaml
from
det_keypoint_unite_infer
import
predict_with_given_det
from
infer
import
bench_log
from
infer
import
Detector
from
infer
import
get_test_images
from
infer
import
PredictConfig
from
infer
import
print_arguments
from
keypoint_infer
import
KeyPointDetector
from
keypoint_infer
import
PredictConfig_KeyPoint
from
keypoint_postprocess
import
translate_to_ori_images
from
preprocess
import
base64_to_cv2
from
preprocess
import
decode_image
from
visualize
import
visualize_pose
import
paddlehub.vision.transforms
as
T
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.module
import
runnable
from
paddlehub.module.module
import
serving
@
moduleinfo
(
name
=
"pp-tinypose"
,
type
=
"CV/image_editing"
,
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"Openpose_body_estimation is a body pose estimation model based on Realtime Multi-Person 2D Pose
\
Estimation using Part Affinity Fields."
,
version
=
"1.0.0"
)
class
PP_TinyPose
:
"""
PP-TinyPose Model.
Args:
load_checkpoint(str): Checkpoint save path, default is None.
"""
def
__init__
(
self
):
self
.
det_model_dir
=
os
.
path
.
join
(
self
.
directory
,
'model/picodet_s_320_coco_lcnet/'
)
self
.
keypoint_model_dir
=
os
.
path
.
join
(
self
.
directory
,
'model/dark_hrnet_w32_256x192/'
)
self
.
detector
=
Detector
(
self
.
det_model_dir
)
self
.
topdown_keypoint_detector
=
KeyPointDetector
(
self
.
keypoint_model_dir
)
def
predict
(
self
,
img
:
Union
[
str
,
np
.
ndarray
],
save_path
:
str
=
"pp_tinypose_output"
,
visualization
:
bool
=
False
,
use_gpu
=
False
):
if
use_gpu
:
device
=
'GPU'
else
:
device
=
'CPU'
if
self
.
detector
.
device
!=
device
:
self
.
detector
=
Detector
(
self
.
det_model_dir
,
device
=
device
)
self
.
topdown_keypoint_detector
=
KeyPointDetector
(
self
.
keypoint_model_dir
,
device
=
device
)
self
.
visualization
=
visualization
store_res
=
[]
# Decode image in advance in det + pose prediction
image
,
_
=
decode_image
(
img
,
{})
results
=
self
.
detector
.
predict_image
([
image
],
visual
=
False
)
results
=
self
.
detector
.
filter_box
(
results
,
0.5
)
if
results
[
'boxes_num'
]
>
0
:
keypoint_res
=
predict_with_given_det
(
image
,
results
,
self
.
topdown_keypoint_detector
,
1
,
False
)
save_name
=
img
if
isinstance
(
img
,
str
)
else
(
str
(
time
.
time
())
+
'.png'
)
store_res
.
append
(
[
save_name
,
keypoint_res
[
'bbox'
],
[
keypoint_res
[
'keypoint'
][
0
],
keypoint_res
[
'keypoint'
][
1
]]])
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
if
self
.
visualization
:
visualize_pose
(
save_name
,
keypoint_res
,
visual_thresh
=
0.5
,
save_dir
=
save_path
)
return
store_res
@
serving
def
serving_method
(
self
,
images
:
list
,
**
kwargs
):
"""
Run as a service.
"""
images_decode
=
[
base64_to_cv2
(
image
)
for
image
in
images
]
results
=
self
.
predict
(
img
=
images_decode
[
0
],
**
kwargs
)
results
=
json
.
dumps
(
results
)
return
results
@
runnable
def
run_cmd
(
self
,
argvs
:
list
):
"""
Run as a command.
"""
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
"Run the {} module."
.
format
(
self
.
name
),
prog
=
'hub run {}'
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
results
=
self
.
predict
(
img
=
args
.
input_path
,
save_path
=
args
.
output_dir
,
visualization
=
args
.
visualization
,
use_gpu
=
args
.
use_gpu
)
return
results
def
add_module_config_arg
(
self
):
"""
Add the command config options.
"""
self
.
arg_config_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'pp_tinypose_output'
,
help
=
"The directory to save output images."
)
self
.
arg_config_group
.
add_argument
(
'--visualization'
,
type
=
bool
,
default
=
True
,
help
=
"whether to save output as images."
)
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
action
=
'store_true'
,
help
=
"use GPU or not"
)
def
add_module_input_arg
(
self
):
"""
Add the command input options.
"""
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
modules/image/keypoint_detection/pp-tinypose/preprocess.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
base64
import
cv2
import
numpy
as
np
from
keypoint_preprocess
import
get_affine_transform
def
decode_image
(
im_file
,
im_info
):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if
isinstance
(
im_file
,
str
):
with
open
(
im_file
,
'rb'
)
as
f
:
im_read
=
f
.
read
()
data
=
np
.
frombuffer
(
im_read
,
dtype
=
'uint8'
)
im
=
cv2
.
imdecode
(
data
,
1
)
# BGR mode, but need RGB mode
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
else
:
im
=
cv2
.
cvtColor
(
im_file
,
cv2
.
COLOR_BGR2RGB
)
im_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
],
dtype
=
np
.
float32
)
im_info
[
'scale_factor'
]
=
np
.
array
([
1.
,
1.
],
dtype
=
np
.
float32
)
return
im
,
im_info
class
Resize
(
object
):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def
__init__
(
self
,
target_size
,
keep_ratio
=
True
,
interp
=
cv2
.
INTER_LINEAR
):
if
isinstance
(
target_size
,
int
):
target_size
=
[
target_size
,
target_size
]
self
.
target_size
=
target_size
self
.
keep_ratio
=
keep_ratio
self
.
interp
=
interp
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert
len
(
self
.
target_size
)
==
2
assert
self
.
target_size
[
0
]
>
0
and
self
.
target_size
[
1
]
>
0
im_channel
=
im
.
shape
[
2
]
im_scale_y
,
im_scale_x
=
self
.
generate_scale
(
im
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
]).
astype
(
'float32'
)
im_info
[
'scale_factor'
]
=
np
.
array
([
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
return
im
,
im_info
def
generate_scale
(
self
,
im
):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape
=
im
.
shape
[:
2
]
im_c
=
im
.
shape
[
2
]
if
self
.
keep_ratio
:
im_size_min
=
np
.
min
(
origin_shape
)
im_size_max
=
np
.
max
(
origin_shape
)
target_size_min
=
np
.
min
(
self
.
target_size
)
target_size_max
=
np
.
max
(
self
.
target_size
)
im_scale
=
float
(
target_size_min
)
/
float
(
im_size_min
)
if
np
.
round
(
im_scale
*
im_size_max
)
>
target_size_max
:
im_scale
=
float
(
target_size_max
)
/
float
(
im_size_max
)
im_scale_x
=
im_scale
im_scale_y
=
im_scale
else
:
resize_h
,
resize_w
=
self
.
target_size
im_scale_y
=
resize_h
/
float
(
origin_shape
[
0
])
im_scale_x
=
resize_w
/
float
(
origin_shape
[
1
])
return
im_scale_y
,
im_scale_x
class
NormalizeImage
(
object
):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def
__init__
(
self
,
mean
,
std
,
is_scale
=
True
):
self
.
mean
=
mean
self
.
std
=
std
self
.
is_scale
=
is_scale
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im
=
im
.
astype
(
np
.
float32
,
copy
=
False
)
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
if
self
.
is_scale
:
im
=
im
/
255.0
im
-=
mean
im
/=
std
return
im
,
im_info
class
Permute
(
object
):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def
__init__
(
self
,
):
super
(
Permute
,
self
).
__init__
()
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im
=
im
.
transpose
((
2
,
0
,
1
)).
copy
()
return
im
,
im_info
class
PadStride
(
object
):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def
__init__
(
self
,
stride
=
0
):
self
.
coarsest_stride
=
stride
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride
=
self
.
coarsest_stride
if
coarsest_stride
<=
0
:
return
im
,
im_info
im_c
,
im_h
,
im_w
=
im
.
shape
pad_h
=
int
(
np
.
ceil
(
float
(
im_h
)
/
coarsest_stride
)
*
coarsest_stride
)
pad_w
=
int
(
np
.
ceil
(
float
(
im_w
)
/
coarsest_stride
)
*
coarsest_stride
)
padding_im
=
np
.
zeros
((
im_c
,
pad_h
,
pad_w
),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
im
return
padding_im
,
im_info
class
LetterBoxResize
(
object
):
def
__init__
(
self
,
target_size
):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super
(
LetterBoxResize
,
self
).
__init__
()
if
isinstance
(
target_size
,
int
):
target_size
=
[
target_size
,
target_size
]
self
.
target_size
=
target_size
def
letterbox
(
self
,
img
,
height
,
width
,
color
=
(
127.5
,
127.5
,
127.5
)):
# letterbox: resize a rectangular image to a padded rectangular
shape
=
img
.
shape
[:
2
]
# [height, width]
ratio_h
=
float
(
height
)
/
shape
[
0
]
ratio_w
=
float
(
width
)
/
shape
[
1
]
ratio
=
min
(
ratio_h
,
ratio_w
)
new_shape
=
(
round
(
shape
[
1
]
*
ratio
),
round
(
shape
[
0
]
*
ratio
))
# [width, height]
padw
=
(
width
-
new_shape
[
0
])
/
2
padh
=
(
height
-
new_shape
[
1
])
/
2
top
,
bottom
=
round
(
padh
-
0.1
),
round
(
padh
+
0.1
)
left
,
right
=
round
(
padw
-
0.1
),
round
(
padw
+
0.1
)
img
=
cv2
.
resize
(
img
,
new_shape
,
interpolation
=
cv2
.
INTER_AREA
)
# resized, no border
img
=
cv2
.
copyMakeBorder
(
img
,
top
,
bottom
,
left
,
right
,
cv2
.
BORDER_CONSTANT
,
value
=
color
)
# padded rectangular
return
img
,
ratio
,
padw
,
padh
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert
len
(
self
.
target_size
)
==
2
assert
self
.
target_size
[
0
]
>
0
and
self
.
target_size
[
1
]
>
0
height
,
width
=
self
.
target_size
h
,
w
=
im
.
shape
[:
2
]
im
,
ratio
,
padw
,
padh
=
self
.
letterbox
(
im
,
height
=
height
,
width
=
width
)
new_shape
=
[
round
(
h
*
ratio
),
round
(
w
*
ratio
)]
im_info
[
'im_shape'
]
=
np
.
array
(
new_shape
,
dtype
=
np
.
float32
)
im_info
[
'scale_factor'
]
=
np
.
array
([
ratio
,
ratio
],
dtype
=
np
.
float32
)
return
im
,
im_info
class
Pad
(
object
):
def
__init__
(
self
,
size
,
fill_value
=
[
114.0
,
114.0
,
114.0
]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super
(
Pad
,
self
).
__init__
()
if
isinstance
(
size
,
int
):
size
=
[
size
,
size
]
self
.
size
=
size
self
.
fill_value
=
fill_value
def
__call__
(
self
,
im
,
im_info
):
im_h
,
im_w
=
im
.
shape
[:
2
]
h
,
w
=
self
.
size
if
h
==
im_h
and
w
==
im_w
:
im
=
im
.
astype
(
np
.
float32
)
return
im
,
im_info
canvas
=
np
.
ones
((
h
,
w
,
3
),
dtype
=
np
.
float32
)
canvas
*=
np
.
array
(
self
.
fill_value
,
dtype
=
np
.
float32
)
canvas
[
0
:
im_h
,
0
:
im_w
,
:]
=
im
.
astype
(
np
.
float32
)
im
=
canvas
return
im
,
im_info
class
WarpAffine
(
object
):
"""Warp affine the image
"""
def
__init__
(
self
,
keep_res
=
False
,
pad
=
31
,
input_h
=
512
,
input_w
=
512
,
scale
=
0.4
,
shift
=
0.1
):
self
.
keep_res
=
keep_res
self
.
pad
=
pad
self
.
input_h
=
input_h
self
.
input_w
=
input_w
self
.
scale
=
scale
self
.
shift
=
shift
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_RGB2BGR
)
h
,
w
=
img
.
shape
[:
2
]
if
self
.
keep_res
:
input_h
=
(
h
|
self
.
pad
)
+
1
input_w
=
(
w
|
self
.
pad
)
+
1
s
=
np
.
array
([
input_w
,
input_h
],
dtype
=
np
.
float32
)
c
=
np
.
array
([
w
//
2
,
h
//
2
],
dtype
=
np
.
float32
)
else
:
s
=
max
(
h
,
w
)
*
1.0
input_h
,
input_w
=
self
.
input_h
,
self
.
input_w
c
=
np
.
array
([
w
/
2.
,
h
/
2.
],
dtype
=
np
.
float32
)
trans_input
=
get_affine_transform
(
c
,
s
,
0
,
[
input_w
,
input_h
])
img
=
cv2
.
resize
(
img
,
(
w
,
h
))
inp
=
cv2
.
warpAffine
(
img
,
trans_input
,
(
input_w
,
input_h
),
flags
=
cv2
.
INTER_LINEAR
)
return
inp
,
im_info
def
preprocess
(
im
,
preprocess_ops
):
# process image by preprocess_ops
im_info
=
{
'scale_factor'
:
np
.
array
([
1.
,
1.
],
dtype
=
np
.
float32
),
'im_shape'
:
None
,
}
im
,
im_info
=
decode_image
(
im
,
im_info
)
for
operator
in
preprocess_ops
:
im
,
im_info
=
operator
(
im
,
im_info
)
return
im
,
im_info
def
cv2_to_base64
(
image
:
np
.
ndarray
):
data
=
cv2
.
imencode
(
'.jpg'
,
image
)[
1
]
return
base64
.
b64encode
(
data
.
tostring
()).
decode
(
'utf8'
)
def
base64_to_cv2
(
b64str
:
str
):
data
=
base64
.
b64decode
(
b64str
.
encode
(
'utf8'
))
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
modules/image/keypoint_detection/pp-tinypose/utils.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
import
os
import
time
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--model_dir"
,
type
=
str
,
default
=
None
,
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
),
required
=
True
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of image file."
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
,
default
=
None
,
help
=
"Dir of image file, `image_file` has a higher priority."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
,
help
=
"batch_size for inference."
)
parser
.
add_argument
(
"--video_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser
.
add_argument
(
"--camera_id"
,
type
=
int
,
default
=-
1
,
help
=
"device id of camera to predict."
)
parser
.
add_argument
(
"--threshold"
,
type
=
float
,
default
=
0.5
,
help
=
"Threshold of score."
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Directory of output visualization files."
)
parser
.
add_argument
(
"--run_mode"
,
type
=
str
,
default
=
'paddle'
,
help
=
"mode of running(paddle/trt_fp32/trt_fp16/trt_int8)"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'cpu'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Deprecated, please use `--device`."
)
parser
.
add_argument
(
"--run_benchmark"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether to predict a image_file repeatedly for benchmark"
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn with CPU."
)
parser
.
add_argument
(
"--enable_mkldnn_bfloat16"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use mkldnn bfloat16 inference with CPU."
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
1
,
help
=
"Num of threads with CPU."
)
parser
.
add_argument
(
"--trt_min_shape"
,
type
=
int
,
default
=
1
,
help
=
"min_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_max_shape"
,
type
=
int
,
default
=
1280
,
help
=
"max_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_opt_shape"
,
type
=
int
,
default
=
640
,
help
=
"opt_shape for TensorRT."
)
parser
.
add_argument
(
"--trt_calib_mode"
,
type
=
bool
,
default
=
False
,
help
=
"If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True."
)
parser
.
add_argument
(
'--save_images'
,
action
=
'store_true'
,
help
=
'Save visualization image results.'
)
parser
.
add_argument
(
'--save_mot_txts'
,
action
=
'store_true'
,
help
=
'Save tracking results (txt).'
)
parser
.
add_argument
(
'--save_mot_txt_per_img'
,
action
=
'store_true'
,
help
=
'Save tracking results (txt) for each image.'
)
parser
.
add_argument
(
'--scaled'
,
type
=
bool
,
default
=
False
,
help
=
"Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector."
)
parser
.
add_argument
(
"--tracker_config"
,
type
=
str
,
default
=
None
,
help
=
(
"tracker donfig"
))
parser
.
add_argument
(
"--reid_model_dir"
,
type
=
str
,
default
=
None
,
help
=
(
"Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."
))
parser
.
add_argument
(
"--reid_batch_size"
,
type
=
int
,
default
=
50
,
help
=
"max batch_size for reid model inference."
)
parser
.
add_argument
(
'--use_dark'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'whether to use darkpose to get better keypoint position predict '
)
parser
.
add_argument
(
"--action_file"
,
type
=
str
,
default
=
None
,
help
=
"Path of input file for action recognition."
)
parser
.
add_argument
(
"--window_size"
,
type
=
int
,
default
=
50
,
help
=
"Temporal size of skeleton feature for action recognition."
)
parser
.
add_argument
(
"--random_pad"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether do random padding for action recognition."
)
parser
.
add_argument
(
"--save_results"
,
type
=
bool
,
default
=
False
,
help
=
"Whether save detection result to file using coco format"
)
return
parser
class
Times
(
object
):
def
__init__
(
self
):
self
.
time
=
0.
# start time
self
.
st
=
0.
# end time
self
.
et
=
0.
def
start
(
self
):
self
.
st
=
time
.
time
()
def
end
(
self
,
repeats
=
1
,
accumulative
=
True
):
self
.
et
=
time
.
time
()
if
accumulative
:
self
.
time
+=
(
self
.
et
-
self
.
st
)
/
repeats
else
:
self
.
time
=
(
self
.
et
-
self
.
st
)
/
repeats
def
reset
(
self
):
self
.
time
=
0.
self
.
st
=
0.
self
.
et
=
0.
def
value
(
self
):
return
round
(
self
.
time
,
4
)
class
Timer
(
Times
):
def
__init__
(
self
,
with_tracker
=
False
):
super
(
Timer
,
self
).
__init__
()
self
.
with_tracker
=
with_tracker
self
.
preprocess_time_s
=
Times
()
self
.
inference_time_s
=
Times
()
self
.
postprocess_time_s
=
Times
()
self
.
tracking_time_s
=
Times
()
self
.
img_num
=
0
def
info
(
self
,
average
=
False
):
pre_time
=
self
.
preprocess_time_s
.
value
()
infer_time
=
self
.
inference_time_s
.
value
()
post_time
=
self
.
postprocess_time_s
.
value
()
track_time
=
self
.
tracking_time_s
.
value
()
total_time
=
pre_time
+
infer_time
+
post_time
if
self
.
with_tracker
:
total_time
=
total_time
+
track_time
total_time
=
round
(
total_time
,
4
)
print
(
"------------------ Inference Time Info ----------------------"
)
print
(
"total_time(ms): {}, img_num: {}"
.
format
(
total_time
*
1000
,
self
.
img_num
))
preprocess_time
=
round
(
pre_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
pre_time
postprocess_time
=
round
(
post_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
post_time
inference_time
=
round
(
infer_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
infer_time
tracking_time
=
round
(
track_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
track_time
average_latency
=
total_time
/
max
(
1
,
self
.
img_num
)
qps
=
0
if
total_time
>
0
:
qps
=
1
/
average_latency
print
(
"average latency time(ms): {:.2f}, QPS: {:2f}"
.
format
(
average_latency
*
1000
,
qps
))
if
self
.
with_tracker
:
print
(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}"
.
format
(
preprocess_time
*
1000
,
inference_time
*
1000
,
postprocess_time
*
1000
,
tracking_time
*
1000
))
else
:
print
(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}"
.
format
(
preprocess_time
*
1000
,
inference_time
*
1000
,
postprocess_time
*
1000
))
def
report
(
self
,
average
=
False
):
dic
=
{}
pre_time
=
self
.
preprocess_time_s
.
value
()
infer_time
=
self
.
inference_time_s
.
value
()
post_time
=
self
.
postprocess_time_s
.
value
()
track_time
=
self
.
tracking_time_s
.
value
()
dic
[
'preprocess_time_s'
]
=
round
(
pre_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
pre_time
dic
[
'inference_time_s'
]
=
round
(
infer_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
infer_time
dic
[
'postprocess_time_s'
]
=
round
(
post_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
post_time
dic
[
'img_num'
]
=
self
.
img_num
total_time
=
pre_time
+
infer_time
+
post_time
if
self
.
with_tracker
:
dic
[
'tracking_time_s'
]
=
round
(
track_time
/
max
(
1
,
self
.
img_num
),
4
)
if
average
else
track_time
total_time
=
total_time
+
track_time
dic
[
'total_time_s'
]
=
round
(
total_time
,
4
)
return
dic
def
get_current_memory_mb
():
"""
It is used to Obtain the memory usage of the CPU and GPU during the running of the program.
And this function Current program is time-consuming.
"""
import
pynvml
import
psutil
import
GPUtil
gpu_id
=
int
(
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
0
))
pid
=
os
.
getpid
()
p
=
psutil
.
Process
(
pid
)
info
=
p
.
memory_full_info
()
cpu_mem
=
info
.
uss
/
1024.
/
1024.
gpu_mem
=
0
gpu_percent
=
0
gpus
=
GPUtil
.
getGPUs
()
if
gpu_id
is
not
None
and
len
(
gpus
)
>
0
:
gpu_percent
=
gpus
[
gpu_id
].
load
pynvml
.
nvmlInit
()
handle
=
pynvml
.
nvmlDeviceGetHandleByIndex
(
0
)
meminfo
=
pynvml
.
nvmlDeviceGetMemoryInfo
(
handle
)
gpu_mem
=
meminfo
.
used
/
1024.
/
1024.
return
round
(
cpu_mem
,
4
),
round
(
gpu_mem
,
4
),
round
(
gpu_percent
,
4
)
modules/image/keypoint_detection/pp-tinypose/visualize.py
0 → 100644
浏览文件 @
e0c027f3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
os
import
cv2
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
ImageDraw
from
PIL
import
ImageFile
ImageFile
.
LOAD_TRUNCATED_IMAGES
=
True
import
math
def
visualize_box
(
im
,
results
,
labels
,
threshold
=
0.5
):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, im_h, im_w]
labels (list): labels:['class1', ..., 'classn']
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
"""
if
isinstance
(
im
,
str
):
im
=
Image
.
open
(
im
).
convert
(
'RGB'
)
elif
isinstance
(
im
,
np
.
ndarray
):
im
=
Image
.
fromarray
(
im
)
if
'boxes'
in
results
and
len
(
results
[
'boxes'
])
>
0
:
im
=
draw_box
(
im
,
results
[
'boxes'
],
labels
,
threshold
=
threshold
)
return
im
def
get_color_map_list
(
num_classes
):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map
=
num_classes
*
[
0
,
0
,
0
]
for
i
in
range
(
0
,
num_classes
):
j
=
0
lab
=
i
while
lab
:
color_map
[
i
*
3
]
|=
(((
lab
>>
0
)
&
1
)
<<
(
7
-
j
))
color_map
[
i
*
3
+
1
]
|=
(((
lab
>>
1
)
&
1
)
<<
(
7
-
j
))
color_map
[
i
*
3
+
2
]
|=
(((
lab
>>
2
)
&
1
)
<<
(
7
-
j
))
j
+=
1
lab
>>=
3
color_map
=
[
color_map
[
i
:
i
+
3
]
for
i
in
range
(
0
,
len
(
color_map
),
3
)]
return
color_map
def
draw_box
(
im
,
np_boxes
,
labels
,
threshold
=
0.5
):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
threshold (float): threshold of box
Returns:
im (PIL.Image.Image): visualized image
"""
draw_thickness
=
min
(
im
.
size
)
//
320
draw
=
ImageDraw
.
Draw
(
im
)
clsid2color
=
{}
color_list
=
get_color_map_list
(
len
(
labels
))
expect_boxes
=
(
np_boxes
[:,
1
]
>
threshold
)
&
(
np_boxes
[:,
0
]
>
-
1
)
np_boxes
=
np_boxes
[
expect_boxes
,
:]
for
dt
in
np_boxes
:
clsid
,
bbox
,
score
=
int
(
dt
[
0
]),
dt
[
2
:],
dt
[
1
]
if
clsid
not
in
clsid2color
:
clsid2color
[
clsid
]
=
color_list
[
clsid
]
color
=
tuple
(
clsid2color
[
clsid
])
if
len
(
bbox
)
==
4
:
xmin
,
ymin
,
xmax
,
ymax
=
bbox
print
(
'class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
'right_bottom:[{:.2f},{:.2f}]'
.
format
(
int
(
clsid
),
score
,
xmin
,
ymin
,
xmax
,
ymax
))
# draw bbox
draw
.
line
([(
xmin
,
ymin
),
(
xmin
,
ymax
),
(
xmax
,
ymax
),
(
xmax
,
ymin
),
(
xmin
,
ymin
)],
width
=
draw_thickness
,
fill
=
color
)
elif
len
(
bbox
)
==
8
:
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
=
bbox
draw
.
line
([(
x1
,
y1
),
(
x2
,
y2
),
(
x3
,
y3
),
(
x4
,
y4
),
(
x1
,
y1
)],
width
=
2
,
fill
=
color
)
xmin
=
min
(
x1
,
x2
,
x3
,
x4
)
ymin
=
min
(
y1
,
y2
,
y3
,
y4
)
# draw label
text
=
"{} {:.4f}"
.
format
(
labels
[
clsid
],
score
)
tw
,
th
=
draw
.
textsize
(
text
)
draw
.
rectangle
([(
xmin
+
1
,
ymin
-
th
),
(
xmin
+
tw
+
1
,
ymin
)],
fill
=
color
)
draw
.
text
((
xmin
+
1
,
ymin
-
th
),
text
,
fill
=
(
255
,
255
,
255
))
return
im
def
get_color
(
idx
):
idx
=
idx
*
3
color
=
((
37
*
idx
)
%
255
,
(
17
*
idx
)
%
255
,
(
29
*
idx
)
%
255
)
return
color
def
visualize_pose
(
imgfile
,
results
,
visual_thresh
=
0.6
,
save_name
=
'pose.jpg'
,
save_dir
=
'output'
,
returnimg
=
False
,
ids
=
None
):
try
:
import
matplotlib.pyplot
as
plt
import
matplotlib
plt
.
switch_backend
(
'agg'
)
except
Exception
as
e
:
raise
e
skeletons
,
scores
=
results
[
'keypoint'
]
skeletons
=
np
.
array
(
skeletons
)
kpt_nums
=
17
if
len
(
skeletons
)
>
0
:
kpt_nums
=
skeletons
.
shape
[
1
]
if
kpt_nums
==
17
:
#plot coco keypoint
EDGES
=
[(
0
,
1
),
(
0
,
2
),
(
1
,
3
),
(
2
,
4
),
(
3
,
5
),
(
4
,
6
),
(
5
,
7
),
(
6
,
8
),
(
7
,
9
),
(
8
,
10
),
(
5
,
11
),
(
6
,
12
),
(
11
,
13
),
(
12
,
14
),
(
13
,
15
),
(
14
,
16
),
(
11
,
12
)]
else
:
#plot mpii keypoint
EDGES
=
[(
0
,
1
),
(
1
,
2
),
(
3
,
4
),
(
4
,
5
),
(
2
,
6
),
(
3
,
6
),
(
6
,
7
),
(
7
,
8
),
(
8
,
9
),
(
10
,
11
),
(
11
,
12
),
(
13
,
14
),
(
14
,
15
),
(
8
,
12
),
(
8
,
13
)]
NUM_EDGES
=
len
(
EDGES
)
colors
=
[[
255
,
0
,
0
],
[
255
,
85
,
0
],
[
255
,
170
,
0
],
[
255
,
255
,
0
],
[
170
,
255
,
0
],
[
85
,
255
,
0
],
[
0
,
255
,
0
],
\
[
0
,
255
,
85
],
[
0
,
255
,
170
],
[
0
,
255
,
255
],
[
0
,
170
,
255
],
[
0
,
85
,
255
],
[
0
,
0
,
255
],
[
85
,
0
,
255
],
\
[
170
,
0
,
255
],
[
255
,
0
,
255
],
[
255
,
0
,
170
],
[
255
,
0
,
85
]]
cmap
=
matplotlib
.
cm
.
get_cmap
(
'hsv'
)
plt
.
figure
()
img
=
cv2
.
imread
(
imgfile
)
if
type
(
imgfile
)
==
str
else
imgfile
color_set
=
results
[
'colors'
]
if
'colors'
in
results
else
None
if
'bbox'
in
results
and
ids
is
None
:
bboxs
=
results
[
'bbox'
]
for
j
,
rect
in
enumerate
(
bboxs
):
xmin
,
ymin
,
xmax
,
ymax
=
rect
color
=
colors
[
0
]
if
color_set
is
None
else
colors
[
color_set
[
j
]
%
len
(
colors
)]
cv2
.
rectangle
(
img
,
(
xmin
,
ymin
),
(
xmax
,
ymax
),
color
,
1
)
canvas
=
img
.
copy
()
for
i
in
range
(
kpt_nums
):
for
j
in
range
(
len
(
skeletons
)):
if
skeletons
[
j
][
i
,
2
]
<
visual_thresh
:
continue
if
ids
is
None
:
color
=
colors
[
i
]
if
color_set
is
None
else
colors
[
color_set
[
j
]
%
len
(
colors
)]
else
:
color
=
get_color
(
ids
[
j
])
cv2
.
circle
(
canvas
,
tuple
(
skeletons
[
j
][
i
,
0
:
2
].
astype
(
'int32'
)),
2
,
color
,
thickness
=-
1
)
to_plot
=
cv2
.
addWeighted
(
img
,
0.3
,
canvas
,
0.7
,
0
)
fig
=
matplotlib
.
pyplot
.
gcf
()
stickwidth
=
2
for
i
in
range
(
NUM_EDGES
):
for
j
in
range
(
len
(
skeletons
)):
edge
=
EDGES
[
i
]
if
skeletons
[
j
][
edge
[
0
],
2
]
<
visual_thresh
or
skeletons
[
j
][
edge
[
1
],
2
]
<
visual_thresh
:
continue
cur_canvas
=
canvas
.
copy
()
X
=
[
skeletons
[
j
][
edge
[
0
],
1
],
skeletons
[
j
][
edge
[
1
],
1
]]
Y
=
[
skeletons
[
j
][
edge
[
0
],
0
],
skeletons
[
j
][
edge
[
1
],
0
]]
mX
=
np
.
mean
(
X
)
mY
=
np
.
mean
(
Y
)
length
=
((
X
[
0
]
-
X
[
1
])
**
2
+
(
Y
[
0
]
-
Y
[
1
])
**
2
)
**
0.5
angle
=
math
.
degrees
(
math
.
atan2
(
X
[
0
]
-
X
[
1
],
Y
[
0
]
-
Y
[
1
]))
polygon
=
cv2
.
ellipse2Poly
((
int
(
mY
),
int
(
mX
)),
(
int
(
length
/
2
),
stickwidth
),
int
(
angle
),
0
,
360
,
1
)
if
ids
is
None
:
color
=
colors
[
i
]
if
color_set
is
None
else
colors
[
color_set
[
j
]
%
len
(
colors
)]
else
:
color
=
get_color
(
ids
[
j
])
cv2
.
fillConvexPoly
(
cur_canvas
,
polygon
,
color
)
canvas
=
cv2
.
addWeighted
(
canvas
,
0.4
,
cur_canvas
,
0.6
,
0
)
if
returnimg
:
return
canvas
save_name
=
os
.
path
.
join
(
save_dir
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
imgfile
))[
0
]
+
'_vis.jpg'
)
plt
.
imsave
(
save_name
,
canvas
[:,
:,
::
-
1
])
print
(
"keypoint visualize image saved to: "
+
save_name
)
plt
.
close
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录