Unverified commit 5ad5a819, authored by zhiboniu, committed by GitHub

pose bottomup higherhrnet: deploy (#2737)

Parent: c1bd0ac2
# KeyPoint Model Series
## Introduction
- The KeyPoint module of PaddleDetection keeps up with the latest and best algorithms in the field and provides both Top-Down and Bottom-Up solutions to meet different user needs.
#### Model Zoo
| Model | Input Size | Channels | AP (COCO val) | Model Download | Config |
| :---------------- | -------- | ------ | :----------: | :----------------------------------------------------------: | ------------------------------------------------------------ |
| HigherHRNet | 512 | 32 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet | 640 | 32 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet+SWAHR | 512 | 32 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
| HRNet | 256x192 | 32 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/hrnet/hrnet_w32_256x192.yml) |
| HRNet | 384x288 | 32 | 77.8 | [hrnet_w32_384x288.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/hrnet/hrnet_w32_384x288.yml) |
## Quick Start
### 1. Environment Setup
Please refer to the PaddleDetection [installation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/INSTALL_cn.md) to install PaddlePaddle and PaddleDetection.
### 2. Data Preparation
Currently the KeyPoint models are developed on the COCO dataset; other datasets have not been validated yet.
Please follow the PaddleDetection [data preparation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/f0a30f3ba6095ebfdc8fffb6d02766406afc438a/docs/tutorials/PrepareDataSet.md) to prepare the COCO dataset; a quick layout check is sketched below.
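
A minimal sanity check of the expected COCO layout (a sketch, assuming it is run from the PaddleDetection root; the paths follow the `dataset/coco` convention used by the configs in this repo):

```python
import os

coco_root = 'dataset/coco'
expected = [
    'train2017',
    'val2017',
    'annotations/person_keypoints_train2017.json',
    'annotations/person_keypoints_val2017.json',
]
for rel in expected:
    path = os.path.join(coco_root, rel)
    print(('OK      ' if os.path.exists(path) else 'MISSING ') + path)
```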
### 3. Training and Testing
**Training on a single GPU:**
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
```
**Training on multiple GPUs:**
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m paddle.distributed.launch tools/train.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
```
**Model evaluation:**
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
```
**Model inference:**
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=./output/higherhrnet_hrnet_w32_512/model_final.pdparams --infer_dir=../images/ --draw_threshold=0.5 --save_txt=True
```
**Deployment inference:**
```shell
# Export the model
python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=output/higherhrnet_hrnet_w32_512/model_final.pdparams
# Deployment inference
# Standalone keypoint inference (top-down / bottom-up) on an image
python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=../images/xxx.jpeg --use_gpu=True --threshold=0.5
python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=../images/xxx.jpeg --use_gpu=True --threshold=0.5
# Joint deployment inference: top-down keypoint + detector
python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4
```
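
The exported model can also be driven from Python instead of the CLI. A minimal sketch (assuming it is run from `deploy/python` with the HigherHRNet model exported as above; `demo.jpg` is a placeholder path):

```python
# Sketch: programmatic use of the Python deploy API (run from deploy/python).
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from keypoint_visualize import draw_pose

model_dir = 'output_inference/higherhrnet_hrnet_w32_512/'
pred_config = PredictConfig_KeyPoint(model_dir)
detector = KeyPoint_Detector(pred_config, model_dir, use_gpu=True)

results = detector.predict('demo.jpg', threshold=0.5)
draw_pose('demo.jpg', results, visual_thread=0.5)
```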
@@ -57,8 +57,7 @@ LearningRate:
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
regularizer: None
#####data
TrainDataset:
......
@@ -57,7 +57,7 @@ LearningRate:
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
regularizer: None
#####data
......
use_gpu: true
log_iter: 10
save_dir: output
snapshot_epoch: 10
weights: output/higherhrnet_hrnet_w32_640/model_final
epoch: 300
num_joints: &num_joints 17
flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
input_size: &input_size 640
hm_size: &hm_size 160
hm_size_2x: &hm_size_2x 320
max_people: &max_people 30
metric: COCO
IouType: keypoints
num_classes: 1
#####model
architecture: HigherHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams
HigherHRNet:
backbone: HRNet
hrhrnet_head: HrHRNetHead
post_process: HrHRNetPostProcess
flip_perm: *flip_perm
eval_flip: true
HRNet:
width: &width 32
freeze_at: -1
freeze_norm: false
return_idx: [0]
HrHRNetHead:
num_joints: *num_joints
width: *width
loss: HrHRNetLoss
swahr: false
HrHRNetLoss:
num_joints: *num_joints
swahr: false
#####optimizer
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
milestones: [200, 260]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer: None
#####data
TrainDataset:
!KeypointBottomUpCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
EvalDataset:
!KeypointBottomUpCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
test_mode: true
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 0
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomAffine:
max_degree: 30
scale: [0.75, 1.5]
max_shift: 0.2
trainsize: *input_size
hmsize: [*hm_size, *hm_size_2x]
- KeyPointFlip:
flip_prob: 0.5
flip_permutation: *flip_perm
hmsize: [*hm_size, *hm_size_2x]
- ToHeatmaps:
num_joints: *num_joints
hmsize: [*hm_size, *hm_size_2x]
sigma: 2
- TagGenerate:
num_joints: *num_joints
max_people: *max_people
- NormalizePermute:
mean: *global_mean
std: *global_std
batch_size: 20
shuffle: true
drop_last: true
use_shared_memory: true
EvalReader:
sample_transforms:
- EvalAffine:
size: *input_size
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
sample_transforms:
- Decode: {}
- EvalAffine:
size: *input_size
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
@@ -2,7 +2,7 @@ use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_coco_256x192/model_final
weights: output/hrnet_w32_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
......
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_w32_384x288/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: KeyPointTopDownCOCOEval
num_classes: 1
train_height: &train_height 384
train_width: &train_width 288
trainsize: &trainsize [*train_width, *train_height]
hmsize: &hmsize [72, 96]
flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
#####model
architecture: TopDownHRNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams
TopDownHRNet:
backbone: HRNet
post_process: HRNetPostProcess
flip_perm: *flip_perm
num_joints: *num_joints
width: &width 32
loss: KeyPointMSELoss
flip: true
HRNet:
width: *width
freeze_at: -1
freeze_norm: false
return_idx: [0]
KeyPointMSELoss:
use_target_weight: true
#####optimizer
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
milestones: [170, 200]
gamma: 0.1
- !LinearWarmup
start_factor: 0.001
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!KeypointTopDownCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
EvalDataset:
!KeypointTopDownCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json
num_joints: *num_joints
trainsize: *trainsize
pixel_std: *pixel_std
use_gt_bbox: True
image_thre: 0.0
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- RandomFlipHalfBodyTransform:
scale: 0.5
rot: 40
num_joints_half_body: 8
prob_half_body: 0.3
pixel_std: *pixel_std
trainsize: *trainsize
upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
flip_pairs: *flip_perm
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 64
shuffle: true
drop_last: false
EvalReader:
sample_transforms:
- TopDownAffine:
trainsize: *trainsize
- ToHeatmapsTopDown:
hmsize: *hmsize
sigma: 2
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 16
drop_empty: false
TestReader:
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
@@ -65,7 +65,9 @@ class Detector(object):
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor = load_predictor(
model_dir,
@@ -76,7 +78,9 @@ class Detector(object):
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode)
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
@@ -182,7 +186,9 @@ class DetectorSOLOv2(Detector):
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor = load_predictor(
model_dir,
@@ -193,7 +199,9 @@ class DetectorSOLOv2(Detector):
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode)
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
def predict(self, image, threshold=0.5, warmup=0, repeats=1):
@@ -309,7 +317,9 @@ def load_predictor(model_dir,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
@@ -345,8 +355,8 @@ def load_predictor(model_dir,
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(FLAGS.cpu_threads)
if FLAGS.enable_mkldnn:
config.set_cpu_math_library_num_threads(cpu_threads)
if enable_mkldnn:
try:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
@@ -502,7 +512,9 @@ def main():
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode)
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
if pred_config.arch == 'SOLOv2':
detector = DetectorSOLOv2(
pred_config,
@@ -513,7 +525,9 @@ def main():
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode)
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from PIL import Image
import cv2
import numpy as np
import paddle
from topdown_unite_utils import argsparser
from preprocess import decode_image
from infer import Detector, PredictConfig, print_arguments, get_test_images
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from keypoint_visualize import draw_pose
def expand_crop(images, rect, expand_ratio=0.5):
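    # `rect` is one detector output row: [class, score, xmin, ymin, xmax, ymax];
    # only person boxes (class 0) are kept, and the box is enlarged by
    # `expand_ratio` around its center before cropping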
imgh, imgw, c = images.shape
label, _, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
if label != 0:
return None, None
h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
ymin = max(0, int(center[0] - h_half))
ymax = min(imgh - 1, int(center[0] + h_half))
xmin = max(0, int(center[1] - w_half))
xmax = min(imgw - 1, int(center[1] + w_half))
return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax]
def get_person_from_rect(images, results):
det_results = results['boxes']
mask = det_results[:, 1] > FLAGS.det_threshold
valid_rects = det_results[mask]
image_buff = []
for rect in valid_rects:
rect_image, new_rect = expand_crop(images, rect)
if rect_image is None:
continue
image_buff.append([rect_image, new_rect])
return image_buff
def affine_backto_orgimages(keypoint_result, batch_records):
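    # map keypoints predicted on a cropped patch back to the original image by
    # adding the crop's top-left corner (xmin, ymin) from batch_records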
kpts, scores = keypoint_result['keypoint']
kpts[..., 0] += batch_records[0]
kpts[..., 1] += batch_records[1]
return kpts, scores
def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
for i, img_file in enumerate(image_list):
image, _ = decode_image(img_file, {})
results = detector.predict(image, FLAGS.det_threshold)
batchs_images = get_person_from_rect(image, results)
keypoint_vector = []
score_vector = []
        rect_vector = []
for batch_images, batch_records in batchs_images:
keypoint_result = topdown_keypoint_detector.predict(
batch_images, FLAGS.keypoint_threshold)
orgkeypoints, scores = affine_backto_orgimages(keypoint_result,
batch_records)
keypoint_vector.append(orgkeypoints)
score_vector.append(scores)
            rect_vector.append(batch_records)
keypoint_res = {}
keypoint_res['keypoint'] = [
np.vstack(keypoint_vector), np.vstack(score_vector)
]
        keypoint_res['bbox'] = rect_vector
draw_pose(
img_file, keypoint_res, visual_thread=FLAGS.keypoint_threshold)
def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.basename(FLAGS.video_file)
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame:%d' % (index))
index += 1
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict(frame2, FLAGS.det_threshold)
batchs_images = get_person_from_rect(frame, results)
keypoint_vector = []
score_vector = []
        rect_vector = []
for batch_images, batch_records in batchs_images:
keypoint_result = topdown_keypoint_detector.predict(
batch_images, FLAGS.keypoint_threshold)
orgkeypoints, scores = affine_backto_orgimages(keypoint_result,
batch_records)
keypoint_vector.append(orgkeypoints)
score_vector.append(scores)
            rect_vector.append(batch_records)
keypoint_res = {}
keypoint_res['keypoint'] = [
np.vstack(keypoint_vector), np.vstack(score_vector)
]
        keypoint_res['bbox'] = rect_vector
im = draw_pose(
frame,
keypoint_res,
visual_thread=FLAGS.keypoint_threshold,
returnimg=True)
writer.write(im)
if camera_id != -1:
            cv2.imshow('Keypoint Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def main():
pred_config = PredictConfig(FLAGS.det_model_dir)
detector = Detector(
pred_config,
FLAGS.det_model_dir,
use_gpu=FLAGS.use_gpu,
run_mode=FLAGS.run_mode,
use_dynamic_shape=FLAGS.use_dynamic_shape,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir)
topdown_keypoint_detector = KeyPoint_Detector(
pred_config,
FLAGS.keypoint_model_dir,
use_gpu=FLAGS.use_gpu,
run_mode=FLAGS.run_mode,
use_dynamic_shape=FLAGS.use_dynamic_shape,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
topdown_unite_predict_video(detector, topdown_keypoint_detector,
FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
topdown_unite_predict(detector, topdown_keypoint_detector, img_list)
detector.det_times.info(average=True)
topdown_keypoint_detector.det_times.info(average=True)
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import glob
from functools import reduce
from PIL import Image
import cv2
import numpy as np
import paddle
from preprocess import preprocess, NormalizeImage, Permute
from keypoint_preprocess import EvalAffine, TopDownEvalAffine
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from keypoint_visualize import draw_pose
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb, LoggerHelper
from infer import get_test_images, print_arguments
# Global dictionary
KEYPOINT_SUPPORT_MODELS = {
'HigherHRNet': 'keypoint_bottomup',
'HRNet': 'keypoint_topdown'
}
class KeyPoint_Detector(object):
"""
Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
use_gpu (bool): whether use gpu
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
threshold (float): threshold to reserve the result for output.
"""
def __init__(self,
pred_config,
model_dir,
use_gpu=False,
run_mode='fluid',
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor = load_predictor(
model_dir,
run_mode=run_mode,
min_subgraph_size=self.pred_config.min_subgraph_size,
use_gpu=use_gpu,
use_dynamic_shape=use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
def preprocess(self, im):
preprocess_ops = []
for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
im, im_info = preprocess(im, preprocess_ops,
self.pred_config.input_shape)
inputs = create_inputs(im, im_info)
return inputs
def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5):
# postprocess output of predictor
if KEYPOINT_SUPPORT_MODELS[
self.pred_config.arch] == 'keypoint_bottomup':
results = {}
h, w = inputs['im_shape'][0]
preds = [np_boxes]
if np_masks is not None:
preds += np_masks
preds += [h, w]
keypoint_postprocess = HrHRNetPostProcess()
results['keypoint'] = keypoint_postprocess(*preds)
return results
elif KEYPOINT_SUPPORT_MODELS[
self.pred_config.arch] == 'keypoint_topdown':
results = {}
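            # the top-down deploy path treats the whole input image as one
            # person box: center at the image middle and scale = image size /
            # pixel_std (200), matching the training-time convention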
imshape = inputs['im_shape'][:, ::-1]
center = np.round(imshape / 2.)
scale = imshape / 200.
keypoint_postprocess = HRNetPostProcess()
results['keypoint'] = keypoint_postprocess(np_boxes, center, scale)
return results
else:
raise ValueError("Unsupported arch: {}, expect {}".format(
self.pred_config.arch, KEYPOINT_SUPPORT_MODELS))
def predict(self, image, threshold=0.5, warmup=0, repeats=1):
        '''
        Args:
            image (str/np.ndarray): path of an image / np.ndarray read by cv2
            threshold (float): threshold of the predicted keypoint score
        Returns:
            results (dict): 'keypoint' holds (keypoints, scores); keypoints is
                            np.ndarray of shape [N, num_joints, 3] with
                            [x, y, confidence] per joint
        '''
self.det_times.preprocess_time.start()
inputs = self.preprocess(image)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
self.det_times.preprocess_time.end()
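        # warm-up runs happen before the timed loop so that one-off costs
        # (memory allocation, cuDNN autotuning) do not skew the benchmark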
for i in range(warmup):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.pred_config.tagmap:
masks_tensor = self.predictor.get_output_handle(output_names[1])
heat_k = self.predictor.get_output_handle(output_names[2])
inds_k = self.predictor.get_output_handle(output_names[3])
np_masks = [
masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(),
inds_k.copy_to_cpu()
]
self.det_times.inference_time.start()
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.pred_config.tagmap:
masks_tensor = self.predictor.get_output_handle(output_names[1])
heat_k = self.predictor.get_output_handle(output_names[2])
inds_k = self.predictor.get_output_handle(output_names[3])
np_masks = [
masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(),
inds_k.copy_to_cpu()
]
self.det_times.inference_time.end(repeats=repeats)
self.det_times.postprocess_time.start()
results = self.postprocess(
np_boxes, np_masks, inputs, threshold=threshold)
self.det_times.postprocess_time.end()
self.det_times.img_num += 1
return results
def create_inputs(im, im_info):
    """generate input for the model
    Args:
        im (np.ndarray): preprocessed image
        im_info (dict): info of the image
    Returns:
        inputs (dict): input of the model
    """
inputs = {}
inputs['image'] = np.array((im, )).astype('float32')
inputs['im_shape'] = np.array((im_info['im_shape'], )).astype('float32')
return inputs
class PredictConfig_KeyPoint():
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of model.yml
"""
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.archcls = KEYPOINT_SUPPORT_MODELS[yml_conf['arch']]
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list']
self.tagmap = False
if 'keypoint_bottomup' == self.archcls:
self.tagmap = True
self.input_shape = yml_conf['image_shape']
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in KEYPOINT_SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], KEYPOINT_SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir,
run_mode='fluid',
batch_size=1,
use_gpu=False,
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
use_gpu (bool): whether use gpu
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline quantization
            calibration, trt_calib_mode needs to be set True
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need use_gpu == True.
"""
if not use_gpu and not run_mode == 'fluid':
raise ValueError(
"Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.format(run_mode, use_gpu))
config = Config(
os.path.join(model_dir, 'model.pdmodel'),
os.path.join(model_dir, 'model.pdiparams'))
precision_map = {
'trt_int8': Config.Precision.Int8,
'trt_fp32': Config.Precision.Float32,
'trt_fp16': Config.Precision.Half
}
if use_gpu:
# initial GPU memory(M), device ID
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
if enable_mkldnn:
try:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
except Exception as e:
print(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 10,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if use_dynamic_shape:
min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]}
max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]}
opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
print('trt set dynamic shape done!')
# disable print log when predict
config.disable_glog_info()
# enable shared memory
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
predictor = create_predictor(config)
return predictor
def predict_image(detector, image_list):
for i, img_file in enumerate(image_list):
if FLAGS.run_benchmark:
detector.predict(img_file, FLAGS.threshold, warmup=10, repeats=10)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file))
else:
results = detector.predict(img_file, FLAGS.threshold)
draw_pose(img_file, results, visual_thread=FLAGS.threshold)
def predict_video(detector, camera_id):
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.basename(os.path.split(FLAGS.video_file)[-1])
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name + '.mp4')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame:%d' % (index))
index += 1
results = detector.predict(frame, FLAGS.threshold)
im = draw_pose(
frame, results, visual_thread=FLAGS.threshold, returnimg=True)
writer.write(im)
if camera_id != -1:
            cv2.imshow('Keypoint Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def main():
pred_config = PredictConfig_KeyPoint(FLAGS.model_dir)
detector = KeyPoint_Detector(
pred_config,
FLAGS.model_dir,
use_gpu=FLAGS.use_gpu,
run_mode=FLAGS.run_mode,
use_dynamic_shape=FLAGS.use_dynamic_shape,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
mems = {
'cpu_rss': detector.cpu_mem / len(img_list),
'gpu_rss': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
det_logger = LoggerHelper(
FLAGS, detector.det_times.report(average=True), mems)
det_logger.report()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from scipy.optimize import linear_sum_assignment
from collections import abc, defaultdict
import numpy as np
import math
import paddle
import paddle.nn as nn
from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform
class HrHRNetPostProcess(object):
    '''
    HrHRNet postprocess contains:
    1) get topk keypoints in the output heatmap
    2) sample the tagmap's value corresponding to each of the topk coordinates
    3) match different joints into people with the Hungarian algorithm
    4) adjust the coordinates by +-0.25 to decrease the error std
    5) salvage missing joints by checking positivity of heatmap - tagdiff_norm
    Args:
        max_num_people (int): max number of people supported in postprocess
        heat_thresh (float): topk values below this threshold will be ignored
        tag_thresh (float): coordinates whose sampled tagmap values are within this threshold belong to the same person during init
        inputs (list[heatmap]): the output list of the model, [heatmap, heatmap_maxpool, tagmap]; heatmap_maxpool is used to get topk
        original_height, original_width (float): the original image size
    '''
def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.):
self.max_num_people = max_num_people
self.heat_thresh = heat_thresh
self.tag_thresh = tag_thresh
def lerp(self, j, y, x, heatmap):
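        # quarter-pixel refinement: shift each keypoint by +-0.25 px towards
        # the larger of its two neighboring heatmap values (the extra 0.5
        # moves from the pixel corner to the pixel center)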
H, W = heatmap.shape[-2:]
left = np.clip(x - 1, 0, W - 1)
right = np.clip(x + 1, 0, W - 1)
up = np.clip(y - 1, 0, H - 1)
down = np.clip(y + 1, 0, H - 1)
offset_y = np.where(heatmap[j, down, x] > heatmap[j, up, x], 0.25,
-0.25)
offset_x = np.where(heatmap[j, y, right] > heatmap[j, y, left], 0.25,
-0.25)
return offset_y + 0.5, offset_x + 0.5
def __call__(self, heatmap, tagmap, heat_k, inds_k, original_height,
original_width):
N, J, H, W = heatmap.shape
assert N == 1, "only support batch size 1"
heatmap = heatmap[0]
tagmap = tagmap[0]
heats = heat_k[0]
inds_np = inds_k[0]
y = inds_np // W
x = inds_np % W
tags = tagmap[np.arange(J)[None, :].repeat(self.max_num_people),
y.flatten(), x.flatten()].reshape(J, -1, tagmap.shape[-1])
coords = np.stack((y, x), axis=2)
# threshold
mask = heats > self.heat_thresh
# cluster
cluster = defaultdict(lambda: {
'coords': np.zeros((J, 2), dtype=np.float32),
'scores': np.zeros(J, dtype=np.float32),
'tags': []
})
for jid, m in enumerate(mask):
num_valid = m.sum()
if num_valid == 0:
continue
valid_inds = np.where(m)[0]
valid_tags = tags[jid, m, :]
if len(cluster) == 0: # initialize
for i in valid_inds:
tag = tags[jid, i]
key = tag[0]
cluster[key]['tags'].append(tag)
cluster[key]['scores'][jid] = heats[jid, i]
cluster[key]['coords'][jid] = coords[jid, i]
continue
candidates = list(cluster.keys())[:self.max_num_people]
centroids = [
np.mean(
cluster[k]['tags'], axis=0) for k in candidates
]
num_clusters = len(centroids)
# shape is (num_valid, num_clusters, tag_dim)
dist = valid_tags[:, None, :] - np.array(centroids)[None, ...]
l2_dist = np.linalg.norm(dist, ord=2, axis=2)
# modulate dist with heat value, see `use_detection_val`
cost = np.round(l2_dist) * 100 - heats[jid, m, None]
            # pad the cost matrix, otherwise new poses are ignored
if num_valid > num_clusters:
cost = np.pad(cost, ((0, 0), (0, num_valid - num_clusters)),
constant_values=((0, 0), (0, 1e-10)))
rows, cols = linear_sum_assignment(cost)
for y, x in zip(rows, cols):
tag = tags[jid, y]
if y < num_valid and x < num_clusters and \
l2_dist[y, x] < self.tag_thresh:
key = candidates[x] # merge to cluster
else:
key = tag[0] # initialize new cluster
cluster[key]['tags'].append(tag)
cluster[key]['scores'][jid] = heats[jid, y]
cluster[key]['coords'][jid] = coords[jid, y]
# shape is [k, J, 2] and [k, J]
pose_tags = np.array([cluster[k]['tags'] for k in cluster])
pose_coords = np.array([cluster[k]['coords'] for k in cluster])
pose_scores = np.array([cluster[k]['scores'] for k in cluster])
valid = pose_scores > 0
pose_kpts = np.zeros((pose_scores.shape[0], J, 3), dtype=np.float32)
if valid.sum() == 0:
return pose_kpts, pose_kpts
# refine coords
valid_coords = pose_coords[valid].astype(np.int32)
y = valid_coords[..., 0].flatten()
x = valid_coords[..., 1].flatten()
_, j = np.nonzero(valid)
offsets = self.lerp(j, y, x, heatmap)
pose_coords[valid, 0] += offsets[0]
pose_coords[valid, 1] += offsets[1]
# mean score before salvage
mean_score = pose_scores.mean(axis=1)
pose_kpts[valid, 2] = pose_scores[valid]
# salvage missing joints
if True:
for pid, coords in enumerate(pose_coords):
tag_mean = np.array(pose_tags[pid]).mean(axis=0)
norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5
score = heatmap - np.round(norm) # (J, H, W)
flat_score = score.reshape(J, -1)
max_inds = np.argmax(flat_score, axis=1)
max_scores = np.max(flat_score, axis=1)
salvage_joints = (pose_scores[pid] == 0) & (max_scores > 0)
if salvage_joints.sum() == 0:
continue
y = max_inds[salvage_joints] // W
x = max_inds[salvage_joints] % W
offsets = self.lerp(salvage_joints.nonzero()[0], y, x, heatmap)
y = y.astype(np.float32) + offsets[0]
x = x.astype(np.float32) + offsets[1]
pose_coords[pid][salvage_joints, 0] = y
pose_coords[pid][salvage_joints, 1] = x
pose_kpts[pid][salvage_joints, 2] = max_scores[salvage_joints]
pose_kpts[..., :2] = transpred(pose_coords[..., :2][..., ::-1],
original_height, original_width,
min(H, W))
return pose_kpts, mean_score
def transpred(kpts, h, w, s):
trans, _ = get_affine_mat_kernel(h, w, s, inv=True)
return warp_affine_joints(kpts[..., :2].copy(), trans)
def warp_affine_joints(joints, mat):
"""Apply affine transformation defined by the transform matrix on the
joints.
Args:
joints (np.ndarray[..., 2]): Origin coordinate of joints.
        mat (np.ndarray[2, 3]): The affine matrix.
Returns:
matrix (np.ndarray[..., 2]): Result coordinate of joints.
"""
joints = np.array(joints)
shape = joints.shape
joints = joints.reshape(-1, 2)
return np.dot(np.concatenate(
(joints, joints[:, 0:1] * 0 + 1), axis=1),
mat.T).reshape(shape)
class HRNetPostProcess(object):
def flip_back(self, output_flipped, matched_parts):
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def get_max_preds(self, heatmaps):
'''get predictions from score maps
Args:
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
'''
assert isinstance(heatmaps,
np.ndarray), 'heatmaps should be numpy.ndarray'
assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
batch_size = heatmaps.shape[0]
num_joints = heatmaps.shape[1]
width = heatmaps.shape[3]
heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
idx = np.argmax(heatmaps_reshaped, 2)
maxvals = np.amax(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)
preds *= pred_mask
return preds, maxvals
def get_final_preds(self, heatmaps, center, scale):
"""the highest heatvalue location with a quarter offset in the
direction from the highest response to the second highest response.
Args:
heatmaps (numpy.ndarray): The predicted heatmaps
center (numpy.ndarray): The boxes center
scale (numpy.ndarray): The scale factor
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
"""
coords, maxvals = self.get_max_preds(heatmaps)
heatmap_height = heatmaps.shape[2]
heatmap_width = heatmaps.shape[3]
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
diff = np.array([
hm[py][px + 1] - hm[py][px - 1],
hm[py + 1][px] - hm[py - 1][px]
])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()
# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(coords[i], center[i], scale[i],
[heatmap_width, heatmap_height])
return preds, maxvals
def __call__(self, output, center, scale):
preds, maxvals = self.get_final_preds(output, center, scale)
return np.concatenate(
(preds, maxvals), axis=-1), np.mean(
maxvals, axis=1)
def transform_preds(coords, center, scale, output_size):
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale * 200, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
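

# A minimal sanity-check sketch of HRNetPostProcess on random data
# (illustrative only, not part of the deploy pipeline; the shapes follow the
# hrnet_w32_384x288 config: input 288x384, hmsize [72, 96], pixel_std 200).
if __name__ == '__main__':
    np.random.seed(0)
    heatmaps = np.random.rand(1, 17, 96, 72).astype(np.float32)  # (batch, joints, h, w)
    center = np.array([[144., 192.]])  # box center (x, y) in the source image
    scale = np.array([[288., 384.]]) / 200.  # box size [w, h] / pixel_std
    kpts, mean_score = HRNetPostProcess()(heatmaps, center, scale)
    print('keypoints:', kpts.shape)  # (1, 17, 3): x, y, score per joint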
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
class EvalAffine(object):
def __init__(self, size, stride=64):
super(EvalAffine, self).__init__()
self.size = size
self.stride = stride
def __call__(self, image, im_info):
s = self.size
h, w, _ = image.shape
trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False)
image_resized = cv2.warpAffine(image, trans, size_resized)
return image_resized, im_info
def get_affine_mat_kernel(h, w, s, inv=False):
if w < h:
w_ = s
h_ = int(np.ceil((s / w * h) / 64.) * 64)
scale_w = w
scale_h = h_ / w_ * w
else:
h_ = s
w_ = int(np.ceil((s / h * w) / 64.) * 64)
scale_h = h
scale_w = w_ / h_ * h
center = np.array([np.round(w / 2.), np.round(h / 2.)])
size_resized = (w_, h_)
trans = get_affine_transform(
center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv)
return trans, size_resized
def get_affine_transform(center,
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(input_size) == 2
assert len(output_size) == 2
assert len(shift) == 2
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = rotate_point([0., src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def rotate_point(pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
records(dict): the dict contained the image and coords
Returns:
records (dict): contain the image and coords after tranformed
"""
def __init__(self, trainsize):
self.trainsize = trainsize
def __call__(self, image, im_info):
rot = 0
imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2.
scale = im_info['scale'] if 'scale' in im_info else imshape
trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
return image, im_info
# coding: utf-8
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import os
import numpy as np
import math
def map_coco_to_personlab(keypoints):
permute = [0, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3]
return keypoints[:, permute, :]
def draw_pose(imgfile,
results,
visual_thread=0.6,
save_name='pose.jpg',
returnimg=False):
try:
import matplotlib.pyplot as plt
import matplotlib
plt.switch_backend('agg')
    except Exception as e:
        print('Matplotlib not found, please install matplotlib, '
              'for example: `pip install matplotlib`.')
        raise e
EDGES = [(0, 14), (0, 13), (0, 4), (0, 1), (14, 16), (13, 15), (4, 10),
(1, 7), (10, 11), (7, 8), (11, 12), (8, 9), (4, 5), (1, 2), (5, 6),
(2, 3)]
NUM_EDGES = len(EDGES)
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
cmap = matplotlib.cm.get_cmap('hsv')
plt.figure()
img = cv2.imread(imgfile) if type(imgfile) == str else imgfile
skeletons, scores = results['keypoint']
if 'bbox' in results:
bboxs = results['bbox']
for idx, rect in enumerate(bboxs):
xmin, ymin, xmax, ymax = rect
cv2.rectangle(img, (xmin, ymin), (xmax, ymax),
colors[idx % len(colors)], 2)
canvas = img.copy()
for i in range(17):
rgba = np.array(cmap(1 - i / 17. - 1. / 34))
rgba[0:3] *= 255
for j in range(len(skeletons)):
if skeletons[j][i, 2] < visual_thread:
continue
cv2.circle(
canvas,
tuple(skeletons[j][i, 0:2].astype('int32')),
2,
colors[i],
thickness=-1)
to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
fig = matplotlib.pyplot.gcf()
stickwidth = 2
skeletons = map_coco_to_personlab(skeletons)
for i in range(NUM_EDGES):
for j in range(len(skeletons)):
edge = EDGES[i]
if skeletons[j][edge[0], 2] < visual_thread or skeletons[j][edge[
1], 2] < visual_thread:
continue
cur_canvas = canvas.copy()
X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)),
(int(length / 2), stickwidth),
int(angle), 0, 360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
if returnimg:
return canvas
    if not os.path.exists('output'):
        os.makedirs('output')
    save_name = 'output/' + os.path.basename(imgfile)[:-4] + '_vis.jpg'
plt.imsave(save_name, canvas[:, :, ::-1])
print("keypoint visualize image saved to: " + save_name)
plt.close()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import argparse
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--det_model_dir",
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."),
required=True)
parser.add_argument(
"--keypoint_model_dir",
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."),
required=True)
parser.add_argument(
"--image_file", type=str, default=None, help="Path of image file.")
parser.add_argument(
"--image_dir",
type=str,
default=None,
help="Dir of image file, `image_file` has a higher priority.")
parser.add_argument(
"--video_file",
type=str,
default=None,
help="Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument(
"--det_threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
"--keypoint_threshold",
type=float,
default=0.5,
help="Threshold of score.")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory of output visualization files.")
parser.add_argument(
"--run_mode",
type=str,
default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--use_gpu",
type=ast.literal_eval,
default=False,
help="Whether to predict with GPU.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--enable_mkldnn",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn with CPU.")
parser.add_argument(
"--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
parser.add_argument(
"--use_dynamic_shape",
type=ast.literal_eval,
default=False,
help="Dynamic_shape for TensorRT.")
parser.add_argument(
"--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
parser.add_argument(
"--trt_max_shape",
type=int,
default=1280,
help="max_shape for TensorRT.")
parser.add_argument(
"--trt_opt_shape",
type=int,
default=640,
help="opt_shape for TensorRT.")
parser.add_argument(
"--trt_calib_mode",
type=bool,
default=False,
help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.")
return parser
@@ -234,7 +234,7 @@ class OptimizerBuilder():
clip_norm=self.clip_grad_by_norm)
else:
grad_clip = None
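        # note: YAML parses `regularizer: None` as the plain string 'None'
        # rather than a null value, which is why it is compared against here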
if self.regularizer:
if self.regularizer and self.regularizer != 'None':
reg_type = self.regularizer['type'] + 'Decay'
reg_factor = self.regularizer['factor']
regularization = getattr(regularizer, reg_type)(reg_factor)
......