From 5ad5a8199e3b3e1162af50c0a52eca678a259d67 Mon Sep 17 00:00:00 2001
From: zhiboniu <31800336+zhiboniu@users.noreply.github.com>
Date: Fri, 14 May 2021 22:55:11 +0800
Subject: [PATCH] pose bottomup higherhrnet: deploy (#2737)

---
 configs/keypoint/README.md                    |  74 ++++
 .../higherhrnet/higherhrnet_hrnet_w32_512.yml |   3 +-
 .../higherhrnet_hrnet_w32_512_swahr.yml       |   2 +-
 .../higherhrnet/higherhrnet_hrnet_w32_640.yml | 134 ++++++
 ...coco_256x192.yml => hrnet_w32_256x192.yml} |   2 +-
 configs/keypoint/hrnet/hrnet_w32_384x288.yml  | 144 ++++++
 deploy/python/infer.py                        |  32 +-
 deploy/python/keypoint_det_unite_infer.py     | 195 ++++++++
 deploy/python/keypoint_infer.py               | 415 ++++++++++++++++++
 deploy/python/keypoint_postprocess.py         | 302 +++++++++++++
 deploy/python/keypoint_preprocess.py          | 178 ++++++++
 deploy/python/keypoint_visualize.py           | 106 +++++
 deploy/python/topdown_unite_utils.py          | 111 +++++
 ppdet/optimizer.py                            |   2 +-
 14 files changed, 1686 insertions(+), 14 deletions(-)
 create mode 100644 configs/keypoint/README.md
 create mode 100644 configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml
 rename configs/keypoint/hrnet/{hrnet_coco_256x192.yml => hrnet_w32_256x192.yml} (98%)
 create mode 100644 configs/keypoint/hrnet/hrnet_w32_384x288.yml
 create mode 100644 deploy/python/keypoint_det_unite_infer.py
 create mode 100644 deploy/python/keypoint_infer.py
 create mode 100644 deploy/python/keypoint_postprocess.py
 create mode 100644 deploy/python/keypoint_preprocess.py
 create mode 100644 deploy/python/keypoint_visualize.py
 create mode 100644 deploy/python/topdown_unite_utils.py

diff --git a/configs/keypoint/README.md b/configs/keypoint/README.md
new file mode 100644
index 000000000..d0d64b014
--- /dev/null
+++ b/configs/keypoint/README.md
@@ -0,0 +1,74 @@
+# KeyPoint Model Series
+
+
+
+## Introduction
+
+- The KeyPoint models in PaddleDetection follow the latest and best-performing algorithms in the field, and provide both Top-Down and Bottom-Up solutions to meet different user needs.
+
+
+
+#### Model Zoo
+
+| Model | Input Size | Channels | AP(coco val) | Download | Config |
+| :---------------- | -------- | ------ | :----------: | :----------------------------------------------------------: | ------------------------------------------------------------ |
+| HigherHRNet | 512 | 32 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml) |
+| HigherHRNet | 640 | 32 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml) |
+| HigherHRNet+SWAHR | 512 | 32 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
+| HRNet | 256x192 | 32 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/hrnet/hrnet_w32_256x192.yml) |
+| HRNet | 384x288 | 32 | 77.8 | [hrnet_w32_384x288.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/hrnet/hrnet_w32_384x288.yml) |
+
+
+
+## Quick Start
+
+### 1. Environment Setup
+
+Please follow the PaddleDetection [installation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/INSTALL_cn.md) to install PaddlePaddle and PaddleDetection correctly.
+
+### 2. Data Preparation
+
+The KeyPoint models are currently developed on the COCO dataset; other datasets have not been verified yet.
+
+Please follow the PaddleDetection [data preparation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/f0a30f3ba6095ebfdc8fffb6d02766406afc438a/docs/tutorials/PrepareDataSet.md) to prepare the COCO dataset.
+
+### 3. Training and Testing
+
+**Training on a single GPU:**
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
+```
+
+**Training on multiple GPUs:**
+
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m paddle.distributed.launch tools/train.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
+```
+
+**Model evaluation:**
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
+```
+
+**Model inference:**
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python3 tools/infer.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=./output/higherhrnet_hrnet_w32_512/model_final.pdparams --infer_dir=../images/ --draw_threshold=0.5 --save_txt=True
+```
+
+**Deployment inference:**
+
+```shell
+#export the model
+python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=output/higherhrnet_hrnet_w32_512/model_final.pdparams

+#deployment inference
+#standalone keypoint top-down/bottom-up inference on images
+python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=../images/xxx.jpeg --use_gpu=True --threshold=0.5
+python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=../images/xxx.jpeg --use_gpu=True --threshold=0.5
+
+#joint deployment inference: keypoint top-down + detector
+python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4
+```
diff --git a/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml b/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
index f4fcdfbea..e79af4ec7 100644
--- a/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
+++ b/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml
@@ -57,8 +57,7 @@ LearningRate:
 OptimizerBuilder:
   optimizer:
     type: Adam
-  regularizer:
-
+  regularizer: None
 
 #####data
 TrainDataset:
diff --git a/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml b/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml
index a073c472a..599230e57 100644
--- a/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml
+++ b/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml
@@ -57,7 +57,7 @@ LearningRate:
 OptimizerBuilder:
   optimizer:
     type: Adam
-  regularizer:
+  regularizer: None
 
 
 #####data
diff --git a/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml b/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml
new file mode 100644
index 000000000..a310ce908
--- /dev/null
+++ b/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml
@@ -0,0 +1,134 @@
+use_gpu: true
+log_iter: 10
+save_dir: output
+snapshot_epoch: 10
+weights: output/higherhrnet_hrnet_w32_640/model_final
+epoch: 300
+num_joints: &num_joints 17
+flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
+input_size: &input_size 640
+hm_size: &hm_size 160
+hm_size_2x: &hm_size_2x 320
+max_people: &max_people 30
+metric: COCO
+IouType: keypoints
+num_classes: 1
+
+
+#####model
+architecture: HigherHRNet
+pretrain_weights: 
https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams + +HigherHRNet: + backbone: HRNet + hrhrnet_head: HrHRNetHead + post_process: HrHRNetPostProcess + flip_perm: *flip_perm + eval_flip: true + +HRNet: + width: &width 32 + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +HrHRNetHead: + num_joints: *num_joints + width: *width + loss: HrHRNetLoss + swahr: false + +HrHRNetLoss: + num_joints: *num_joints + swahr: false + + +#####optimizer +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + milestones: [200, 260] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.001 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: None + +#####data +TrainDataset: + !KeypointBottomUpCocoDataset + image_dir: train2017 + anno_path: annotations/person_keypoints_train2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + +EvalDataset: + !KeypointBottomUpCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + test_mode: true + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 0 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomAffine: + max_degree: 30 + scale: [0.75, 1.5] + max_shift: 0.2 + trainsize: *input_size + hmsize: [*hm_size, *hm_size_2x] + - KeyPointFlip: + flip_prob: 0.5 + flip_permutation: *flip_perm + hmsize: [*hm_size, *hm_size_2x] + - ToHeatmaps: + num_joints: *num_joints + hmsize: [*hm_size, *hm_size_2x] + sigma: 2 + - TagGenerate: + num_joints: *num_joints + max_people: *max_people + - NormalizePermute: + mean: *global_mean + std: *global_std + batch_size: 20 + shuffle: true + drop_last: true + use_shared_memory: true + +EvalReader: + sample_transforms: + - EvalAffine: + size: *input_size + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + drop_empty: false + +TestReader: + sample_transforms: + - Decode: {} + - EvalAffine: + size: *input_size + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 diff --git a/configs/keypoint/hrnet/hrnet_coco_256x192.yml b/configs/keypoint/hrnet/hrnet_w32_256x192.yml similarity index 98% rename from configs/keypoint/hrnet/hrnet_coco_256x192.yml rename to configs/keypoint/hrnet/hrnet_w32_256x192.yml index 3ecc8461c..78661f2c4 100644 --- a/configs/keypoint/hrnet/hrnet_coco_256x192.yml +++ b/configs/keypoint/hrnet/hrnet_w32_256x192.yml @@ -2,7 +2,7 @@ use_gpu: true log_iter: 5 save_dir: output snapshot_epoch: 10 -weights: output/hrnet_coco_256x192/model_final +weights: output/hrnet_w32_256x192/model_final epoch: 210 num_joints: &num_joints 17 pixel_std: &pixel_std 200 diff --git a/configs/keypoint/hrnet/hrnet_w32_384x288.yml b/configs/keypoint/hrnet/hrnet_w32_384x288.yml new file mode 100644 index 000000000..b6c285856 --- /dev/null +++ b/configs/keypoint/hrnet/hrnet_w32_384x288.yml @@ -0,0 +1,144 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 10 +weights: output/hrnet_w32_384x288/model_final +epoch: 210 +num_joints: &num_joints 17 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownCOCOEval +num_classes: 1 +train_height: &train_height 384 +train_width: &train_width 288 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [72, 96] +flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], 
[15, 16]] + + +#####model +architecture: TopDownHRNet +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams + +TopDownHRNet: + backbone: HRNet + post_process: HRNetPostProcess + flip_perm: *flip_perm + num_joints: *num_joints + width: &width 32 + loss: KeyPointMSELoss + flip: true + +HRNet: + width: *width + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointMSELoss: + use_target_weight: true + + +#####optimizer +LearningRate: + base_lr: 0.0005 + schedulers: + - !PiecewiseDecay + milestones: [170, 200] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.001 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + + +#####data +TrainDataset: + !KeypointTopDownCocoDataset + image_dir: train2017 + anno_path: annotations/person_keypoints_train2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + use_gt_bbox: True + + +EvalDataset: + !KeypointTopDownCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + bbox_file: person_detection_results/COCO_val2017_detections_AP_H_56_person.json + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + use_gt_bbox: True + image_thre: 0.0 + + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 2 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomFlipHalfBodyTransform: + scale: 0.5 + rot: 40 + num_joints_half_body: 8 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + flip_pairs: *flip_perm + - TopDownAffine: + trainsize: *trainsize + - ToHeatmapsTopDown: + hmsize: *hmsize + sigma: 2 + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 64 + shuffle: true + drop_last: false + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + - ToHeatmapsTopDown: + hmsize: *hmsize + sigma: 2 + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 16 + drop_empty: false + +TestReader: + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 diff --git a/deploy/python/infer.py b/deploy/python/infer.py index ea94ac646..a86beb809 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -65,7 +65,9 @@ class Detector(object): trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640, - trt_calib_mode=False): + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): self.pred_config = pred_config self.predictor = load_predictor( model_dir, @@ -76,7 +78,9 @@ class Detector(object): trt_min_shape=trt_min_shape, trt_max_shape=trt_max_shape, trt_opt_shape=trt_opt_shape, - trt_calib_mode=trt_calib_mode) + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) self.det_times = Timer() self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 @@ -182,7 +186,9 @@ class DetectorSOLOv2(Detector): trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640, - trt_calib_mode=False): + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): self.pred_config = pred_config self.predictor = load_predictor( 
model_dir, @@ -193,7 +199,9 @@ class DetectorSOLOv2(Detector): trt_min_shape=trt_min_shape, trt_max_shape=trt_max_shape, trt_opt_shape=trt_opt_shape, - trt_calib_mode=trt_calib_mode) + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) self.det_times = Timer() def predict(self, image, threshold=0.5, warmup=0, repeats=1): @@ -309,7 +317,9 @@ def load_predictor(model_dir, trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640, - trt_calib_mode=False): + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): """set AnalysisConfig, generate AnalysisPredictor Args: model_dir (str): root path of __model__ and __params__ @@ -345,8 +355,8 @@ def load_predictor(model_dir, config.switch_ir_optim(True) else: config.disable_gpu() - config.set_cpu_math_library_num_threads(FLAGS.cpu_threads) - if FLAGS.enable_mkldnn: + config.set_cpu_math_library_num_threads(cpu_threads) + if enable_mkldnn: try: # cache 10 different shapes for mkldnn to avoid memory leak config.set_mkldnn_cache_capacity(10) @@ -502,7 +512,9 @@ def main(): trt_min_shape=FLAGS.trt_min_shape, trt_max_shape=FLAGS.trt_max_shape, trt_opt_shape=FLAGS.trt_opt_shape, - trt_calib_mode=FLAGS.trt_calib_mode) + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) if pred_config.arch == 'SOLOv2': detector = DetectorSOLOv2( pred_config, @@ -513,7 +525,9 @@ def main(): trt_min_shape=FLAGS.trt_min_shape, trt_max_shape=FLAGS.trt_max_shape, trt_opt_shape=FLAGS.trt_opt_shape, - trt_calib_mode=FLAGS.trt_calib_mode) + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/keypoint_det_unite_infer.py b/deploy/python/keypoint_det_unite_infer.py new file mode 100644 index 000000000..06cbaf077 --- /dev/null +++ b/deploy/python/keypoint_det_unite_infer.py @@ -0,0 +1,195 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from PIL import Image +import cv2 +import numpy as np +import paddle + +from topdown_unite_utils import argsparser +from preprocess import decode_image +from infer import Detector, PredictConfig, print_arguments, get_test_images +from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint +from keypoint_visualize import draw_pose + + +def expand_crop(images, rect, expand_ratio=0.5): + imgh, imgw, c = images.shape + label, _, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()] + if label != 0: + return None, None + h_half = (ymax - ymin) * (1 + expand_ratio) / 2. + w_half = (xmax - xmin) * (1 + expand_ratio) / 2. + center = [(ymin + ymax) / 2., (xmin + xmax) / 2.] 
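+    # grow the detected person box by expand_ratio around its center; the expanded
+    # crop is clamped to the image bounds just below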
+ ymin = max(0, int(center[0] - h_half)) + ymax = min(imgh - 1, int(center[0] + h_half)) + xmin = max(0, int(center[1] - w_half)) + xmax = min(imgw - 1, int(center[1] + w_half)) + return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax] + + +def get_person_from_rect(images, results): + det_results = results['boxes'] + mask = det_results[:, 1] > FLAGS.det_threshold + valid_rects = det_results[mask] + image_buff = [] + for rect in valid_rects: + rect_image, new_rect = expand_crop(images, rect) + if rect_image is None: + continue + image_buff.append([rect_image, new_rect]) + return image_buff + + +def affine_backto_orgimages(keypoint_result, batch_records): + kpts, scores = keypoint_result['keypoint'] + kpts[..., 0] += batch_records[0] + kpts[..., 1] += batch_records[1] + return kpts, scores + + +def topdown_unite_predict(detector, topdown_keypoint_detector, image_list): + for i, img_file in enumerate(image_list): + image, _ = decode_image(img_file, {}) + results = detector.predict(image, FLAGS.det_threshold) + batchs_images = get_person_from_rect(image, results) + keypoint_vector = [] + score_vector = [] + rect_vecotr = [] + for batch_images, batch_records in batchs_images: + keypoint_result = topdown_keypoint_detector.predict( + batch_images, FLAGS.keypoint_threshold) + orgkeypoints, scores = affine_backto_orgimages(keypoint_result, + batch_records) + keypoint_vector.append(orgkeypoints) + score_vector.append(scores) + rect_vecotr.append(batch_records) + keypoint_res = {} + keypoint_res['keypoint'] = [ + np.vstack(keypoint_vector), np.vstack(score_vector) + ] + keypoint_res['bbox'] = rect_vecotr + draw_pose( + img_file, keypoint_res, visual_thread=FLAGS.keypoint_threshold) + + +def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id): + if camera_id != -1: + capture = cv2.VideoCapture(camera_id) + video_name = 'output.mp4' + else: + capture = cv2.VideoCapture(FLAGS.video_file) + video_name = os.path.basename( + os.path.split(FLAGS.video_file + '.mp4')[-1]) + fps = 30 + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # yapf: disable + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # yapf: enable + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + out_path = os.path.join(FLAGS.output_dir, video_name) + writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) + index = 1 + while (1): + ret, frame = capture.read() + if not ret: + break + print('detect frame:%d' % (index)) + index += 1 + + frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + results = detector.predict(frame2, FLAGS.det_threshold) + batchs_images = get_person_from_rect(frame, results) + keypoint_vector = [] + score_vector = [] + rect_vecotr = [] + for batch_images, batch_records in batchs_images: + keypoint_result = topdown_keypoint_detector.predict( + batch_images, FLAGS.keypoint_threshold) + orgkeypoints, scores = affine_backto_orgimages(keypoint_result, + batch_records) + keypoint_vector.append(orgkeypoints) + score_vector.append(scores) + rect_vecotr.append(batch_records) + keypoint_res = {} + keypoint_res['keypoint'] = [ + np.vstack(keypoint_vector), np.vstack(score_vector) + ] + keypoint_res['bbox'] = rect_vecotr + im = draw_pose( + frame, + keypoint_res, + visual_thread=FLAGS.keypoint_threshold, + returnimg=True) + + writer.write(im) + if camera_id != -1: + cv2.imshow('Mask Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + writer.release() + + +def main(): + pred_config = 
PredictConfig(FLAGS.det_model_dir) + detector = Detector( + pred_config, + FLAGS.det_model_dir, + use_gpu=FLAGS.use_gpu, + run_mode=FLAGS.run_mode, + use_dynamic_shape=FLAGS.use_dynamic_shape, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) + + pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir) + topdown_keypoint_detector = KeyPoint_Detector( + pred_config, + FLAGS.keypoint_model_dir, + use_gpu=FLAGS.use_gpu, + run_mode=FLAGS.run_mode, + use_dynamic_shape=FLAGS.use_dynamic_shape, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) + + # predict from video file or camera video stream + if FLAGS.video_file is not None or FLAGS.camera_id != -1: + topdown_unite_predict_video(detector, topdown_keypoint_detector, + FLAGS.camera_id) + else: + # predict from image + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + topdown_unite_predict(detector, topdown_keypoint_detector, img_list) + detector.det_times.info(average=True) + topdown_keypoint_detector.det_times.info(average=True) + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + + main() diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py new file mode 100644 index 000000000..5d56a6675 --- /dev/null +++ b/deploy/python/keypoint_infer.py @@ -0,0 +1,415 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
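+
+# Standalone deploy-time inference for exported keypoint models:
+# HRNet (top-down) and HigherHRNet (bottom-up).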
+ +import os +import time +import yaml +import glob +from functools import reduce + +from PIL import Image +import cv2 +import numpy as np +import paddle +from preprocess import preprocess, NormalizeImage, Permute +from keypoint_preprocess import EvalAffine, TopDownEvalAffine +from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess +from keypoint_visualize import draw_pose +from paddle.inference import Config +from paddle.inference import create_predictor +from utils import argsparser, Timer, get_current_memory_mb, LoggerHelper +from infer import get_test_images, print_arguments + +# Global dictionary +KEYPOINT_SUPPORT_MODELS = { + 'HigherHRNet': 'keypoint_bottomup', + 'HRNet': 'keypoint_topdown' +} + + +class KeyPoint_Detector(object): + """ + Args: + config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + use_gpu (bool): whether use gpu + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) + threshold (float): threshold to reserve the result for output. + """ + + def __init__(self, + pred_config, + model_dir, + use_gpu=False, + run_mode='fluid', + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + self.pred_config = pred_config + self.predictor = load_predictor( + model_dir, + run_mode=run_mode, + min_subgraph_size=self.pred_config.min_subgraph_size, + use_gpu=use_gpu, + use_dynamic_shape=use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + + def preprocess(self, im): + preprocess_ops = [] + for op_info in self.pred_config.preprocess_infos: + new_op_info = op_info.copy() + op_type = new_op_info.pop('type') + preprocess_ops.append(eval(op_type)(**new_op_info)) + im, im_info = preprocess(im, preprocess_ops, + self.pred_config.input_shape) + inputs = create_inputs(im, im_info) + return inputs + + def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5): + # postprocess output of predictor + if KEYPOINT_SUPPORT_MODELS[ + self.pred_config.arch] == 'keypoint_bottomup': + results = {} + h, w = inputs['im_shape'][0] + preds = [np_boxes] + if np_masks is not None: + preds += np_masks + preds += [h, w] + keypoint_postprocess = HrHRNetPostProcess() + results['keypoint'] = keypoint_postprocess(*preds) + return results + elif KEYPOINT_SUPPORT_MODELS[ + self.pred_config.arch] == 'keypoint_topdown': + results = {} + imshape = inputs['im_shape'][:, ::-1] + center = np.round(imshape / 2.) + scale = imshape / 200. 
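+            # top-down deploy treats the whole input image as one person box:
+            # center is the image center and scale is in units of pixel_std (200),
+            # the (center, scale) convention expected by HRNetPostProcess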
+ keypoint_postprocess = HRNetPostProcess() + results['keypoint'] = keypoint_postprocess(np_boxes, center, scale) + return results + else: + raise ValueError("Unsupported arch: {}, expect {}".format( + self.pred_config.arch, KEYPOINT_SUPPORT_MODELS)) + + def predict(self, image, threshold=0.5, warmup=0, repeats=1): + ''' + Args: + image (str/np.ndarray): path of image/ np.ndarray read by cv2 + threshold (float): threshold of predicted box' score + Returns: + results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + MaskRCNN's results include 'masks': np.ndarray: + shape: [N, im_h, im_w] + ''' + self.det_times.preprocess_time.start() + inputs = self.preprocess(image) + np_boxes, np_masks = None, None + input_names = self.predictor.get_input_names() + + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + self.det_times.preprocess_time.end() + for i in range(warmup): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + np_boxes = boxes_tensor.copy_to_cpu() + if self.pred_config.tagmap: + masks_tensor = self.predictor.get_output_handle(output_names[1]) + heat_k = self.predictor.get_output_handle(output_names[2]) + inds_k = self.predictor.get_output_handle(output_names[3]) + np_masks = [ + masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(), + inds_k.copy_to_cpu() + ] + + self.det_times.inference_time.start() + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + np_boxes = boxes_tensor.copy_to_cpu() + if self.pred_config.tagmap: + masks_tensor = self.predictor.get_output_handle(output_names[1]) + heat_k = self.predictor.get_output_handle(output_names[2]) + inds_k = self.predictor.get_output_handle(output_names[3]) + np_masks = [ + masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(), + inds_k.copy_to_cpu() + ] + self.det_times.inference_time.end(repeats=repeats) + + self.det_times.postprocess_time.start() + results = self.postprocess( + np_boxes, np_masks, inputs, threshold=threshold) + self.det_times.postprocess_time.end() + self.det_times.img_num += 1 + return results + + +def create_inputs(im, im_info): + """generate input for different model type + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + model_arch (str): model type + Returns: + inputs (dict): input of model + """ + inputs = {} + inputs['image'] = np.array((im, )).astype('float32') + inputs['im_shape'] = np.array((im_info['im_shape'], )).astype('float32') + + return inputs + + +class PredictConfig_KeyPoint(): + """set config of preprocess, postprocess and visualize + Args: + model_dir (str): root path of model.yml + """ + + def __init__(self, model_dir): + # parsing Yaml config for Preprocess + deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + self.check_model(yml_conf) + self.arch = yml_conf['arch'] + self.archcls = KEYPOINT_SUPPORT_MODELS[yml_conf['arch']] + self.preprocess_infos = yml_conf['Preprocess'] + self.min_subgraph_size = yml_conf['min_subgraph_size'] + self.labels = yml_conf['label_list'] + self.tagmap = False + if 'keypoint_bottomup' == self.archcls: + self.tagmap = True + self.input_shape = yml_conf['image_shape'] + self.print_config() + + def 
check_model(self, yml_conf): + """ + Raises: + ValueError: loaded model not in supported model type + """ + for support_model in KEYPOINT_SUPPORT_MODELS: + if support_model in yml_conf['arch']: + return True + raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ + 'arch'], KEYPOINT_SUPPORT_MODELS)) + + def print_config(self): + print('----------- Model Configuration -----------') + print('%s: %s' % ('Model Arch', self.arch)) + print('%s: ' % ('Transform Order')) + for op_info in self.preprocess_infos: + print('--%s: %s' % ('transform op', op_info['type'])) + print('--------------------------------------------') + + +def load_predictor(model_dir, + run_mode='fluid', + batch_size=1, + use_gpu=False, + min_subgraph_size=3, + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + use_gpu (bool): whether use gpu + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8) + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need use_gpu == True. + """ + if not use_gpu and not run_mode == 'fluid': + raise ValueError( + "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}" + .format(run_mode, use_gpu)) + config = Config( + os.path.join(model_dir, 'model.pdmodel'), + os.path.join(model_dir, 'model.pdiparams')) + precision_map = { + 'trt_int8': Config.Precision.Int8, + 'trt_fp32': Config.Precision.Float32, + 'trt_fp16': Config.Precision.Half + } + if use_gpu: + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(cpu_threads) + if enable_mkldnn: + try: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + except Exception as e: + print( + "The current environment does not support `mkldnn`, so disable mkldnn." 
+ ) + pass + + if run_mode in precision_map.keys(): + config.enable_tensorrt_engine( + workspace_size=1 << 10, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[run_mode], + use_static=False, + use_calib_mode=trt_calib_mode) + + if use_dynamic_shape: + min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]} + max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]} + opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]} + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) + print('trt set dynamic shape done!') + + # disable print log when predict + config.disable_glog_info() + # enable shared memory + config.enable_memory_optim() + # disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + predictor = create_predictor(config) + return predictor + + +def predict_image(detector, image_list): + for i, img_file in enumerate(image_list): + if FLAGS.run_benchmark: + detector.predict(img_file, FLAGS.threshold, warmup=10, repeats=10) + cm, gm, gu = get_current_memory_mb() + detector.cpu_mem += cm + detector.gpu_mem += gm + detector.gpu_util += gu + print('Test iter {}, file name:{}'.format(i, img_file)) + else: + results = detector.predict(img_file, FLAGS.threshold) + draw_pose(img_file, results, visual_thread=FLAGS.threshold) + + +def predict_video(detector, camera_id): + if camera_id != -1: + capture = cv2.VideoCapture(camera_id) + video_name = 'output.mp4' + else: + capture = cv2.VideoCapture(FLAGS.video_file) + video_name = os.path.basename(os.path.split(FLAGS.video_file)[-1]) + fps = 30 + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # yapf: disable + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # yapf: enable + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + out_path = os.path.join(FLAGS.output_dir, video_name + '.mp4') + writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) + index = 1 + while (1): + ret, frame = capture.read() + if not ret: + break + + print('detect frame:%d' % (index)) + index += 1 + results = detector.predict(frame, FLAGS.threshold) + im = draw_pose( + frame, results, visual_thread=FLAGS.threshold, returnimg=True) + writer.write(im) + if camera_id != -1: + cv2.imshow('Mask Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + writer.release() + + +def main(): + pred_config = PredictConfig_KeyPoint(FLAGS.model_dir) + detector = KeyPoint_Detector( + pred_config, + FLAGS.model_dir, + use_gpu=FLAGS.use_gpu, + run_mode=FLAGS.run_mode, + use_dynamic_shape=FLAGS.use_dynamic_shape, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) + + # predict from video file or camera video stream + if FLAGS.video_file is not None or FLAGS.camera_id != -1: + predict_video(detector, FLAGS.camera_id) + else: + # predict from image + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + predict_image(detector, img_list) + if not FLAGS.run_benchmark: + detector.det_times.info(average=True) + else: + mems = { + 'cpu_rss': detector.cpu_mem / len(img_list), + 'gpu_rss': detector.gpu_mem / len(img_list), + 'gpu_util': detector.gpu_util * 100 / len(img_list) + } + det_logger = LoggerHelper( + FLAGS, detector.det_times.report(average=True), mems) + det_logger.report() + 
+ +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + + main() diff --git a/deploy/python/keypoint_postprocess.py b/deploy/python/keypoint_postprocess.py new file mode 100644 index 000000000..a26244d41 --- /dev/null +++ b/deploy/python/keypoint_postprocess.py @@ -0,0 +1,302 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from scipy.optimize import linear_sum_assignment +from collections import abc, defaultdict +import numpy as np +import math +import paddle +import paddle.nn as nn +from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform + + +class HrHRNetPostProcess(object): + ''' + HrHRNet postprocess contain: + 1) get topk keypoints in the output heatmap + 2) sample the tagmap's value corresponding to each of the topk coordinate + 3) match different joints to combine to some people with Hungary algorithm + 4) adjust the coordinate by +-0.25 to decrease error std + 5) salvage missing joints by check positivity of heatmap - tagdiff_norm + Args: + max_num_people (int): max number of people support in postprocess + heat_thresh (float): value of topk below this threshhold will be ignored + tag_thresh (float): coord's value sampled in tagmap below this threshold belong to same people for init + + inputs(list[heatmap]): the output list of modle, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk + original_height, original_width (float): the original image size + ''' + + def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.): + self.max_num_people = max_num_people + self.heat_thresh = heat_thresh + self.tag_thresh = tag_thresh + + def lerp(self, j, y, x, heatmap): + H, W = heatmap.shape[-2:] + left = np.clip(x - 1, 0, W - 1) + right = np.clip(x + 1, 0, W - 1) + up = np.clip(y - 1, 0, H - 1) + down = np.clip(y + 1, 0, H - 1) + offset_y = np.where(heatmap[j, down, x] > heatmap[j, up, x], 0.25, + -0.25) + offset_x = np.where(heatmap[j, y, right] > heatmap[j, y, left], 0.25, + -0.25) + return offset_y + 0.5, offset_x + 0.5 + + def __call__(self, heatmap, tagmap, heat_k, inds_k, original_height, + original_width): + + N, J, H, W = heatmap.shape + assert N == 1, "only support batch size 1" + heatmap = heatmap[0] + tagmap = tagmap[0] + heats = heat_k[0] + inds_np = inds_k[0] + y = inds_np // W + x = inds_np % W + tags = tagmap[np.arange(J)[None, :].repeat(self.max_num_people), + y.flatten(), x.flatten()].reshape(J, -1, tagmap.shape[-1]) + coords = np.stack((y, x), axis=2) + # threshold + mask = heats > self.heat_thresh + # cluster + cluster = defaultdict(lambda: { + 'coords': np.zeros((J, 2), dtype=np.float32), + 'scores': np.zeros(J, dtype=np.float32), + 'tags': [] + }) + for jid, m in enumerate(mask): + num_valid = m.sum() + if num_valid == 0: + continue + valid_inds = np.where(m)[0] + valid_tags = tags[jid, m, :] + if len(cluster) == 0: # initialize + for i in valid_inds: + tag = tags[jid, i] + key = 
tag[0] + cluster[key]['tags'].append(tag) + cluster[key]['scores'][jid] = heats[jid, i] + cluster[key]['coords'][jid] = coords[jid, i] + continue + candidates = list(cluster.keys())[:self.max_num_people] + centroids = [ + np.mean( + cluster[k]['tags'], axis=0) for k in candidates + ] + num_clusters = len(centroids) + # shape is (num_valid, num_clusters, tag_dim) + dist = valid_tags[:, None, :] - np.array(centroids)[None, ...] + l2_dist = np.linalg.norm(dist, ord=2, axis=2) + # modulate dist with heat value, see `use_detection_val` + cost = np.round(l2_dist) * 100 - heats[jid, m, None] + # pad the cost matrix, otherwise new pose are ignored + if num_valid > num_clusters: + cost = np.pad(cost, ((0, 0), (0, num_valid - num_clusters)), + constant_values=((0, 0), (0, 1e-10))) + rows, cols = linear_sum_assignment(cost) + for y, x in zip(rows, cols): + tag = tags[jid, y] + if y < num_valid and x < num_clusters and \ + l2_dist[y, x] < self.tag_thresh: + key = candidates[x] # merge to cluster + else: + key = tag[0] # initialize new cluster + cluster[key]['tags'].append(tag) + cluster[key]['scores'][jid] = heats[jid, y] + cluster[key]['coords'][jid] = coords[jid, y] + + # shape is [k, J, 2] and [k, J] + pose_tags = np.array([cluster[k]['tags'] for k in cluster]) + pose_coords = np.array([cluster[k]['coords'] for k in cluster]) + pose_scores = np.array([cluster[k]['scores'] for k in cluster]) + valid = pose_scores > 0 + + pose_kpts = np.zeros((pose_scores.shape[0], J, 3), dtype=np.float32) + if valid.sum() == 0: + return pose_kpts, pose_kpts + + # refine coords + valid_coords = pose_coords[valid].astype(np.int32) + y = valid_coords[..., 0].flatten() + x = valid_coords[..., 1].flatten() + _, j = np.nonzero(valid) + offsets = self.lerp(j, y, x, heatmap) + pose_coords[valid, 0] += offsets[0] + pose_coords[valid, 1] += offsets[1] + + # mean score before salvage + mean_score = pose_scores.mean(axis=1) + pose_kpts[valid, 2] = pose_scores[valid] + + # salvage missing joints + if True: + for pid, coords in enumerate(pose_coords): + tag_mean = np.array(pose_tags[pid]).mean(axis=0) + norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5 + score = heatmap - np.round(norm) # (J, H, W) + flat_score = score.reshape(J, -1) + max_inds = np.argmax(flat_score, axis=1) + max_scores = np.max(flat_score, axis=1) + salvage_joints = (pose_scores[pid] == 0) & (max_scores > 0) + if salvage_joints.sum() == 0: + continue + y = max_inds[salvage_joints] // W + x = max_inds[salvage_joints] % W + offsets = self.lerp(salvage_joints.nonzero()[0], y, x, heatmap) + y = y.astype(np.float32) + offsets[0] + x = x.astype(np.float32) + offsets[1] + pose_coords[pid][salvage_joints, 0] = y + pose_coords[pid][salvage_joints, 1] = x + pose_kpts[pid][salvage_joints, 2] = max_scores[salvage_joints] + pose_kpts[..., :2] = transpred(pose_coords[..., :2][..., ::-1], + original_height, original_width, + min(H, W)) + return pose_kpts, mean_score + + +def transpred(kpts, h, w, s): + trans, _ = get_affine_mat_kernel(h, w, s, inv=True) + + return warp_affine_joints(kpts[..., :2].copy(), trans) + + +def warp_affine_joints(joints, mat): + """Apply affine transformation defined by the transform matrix on the + joints. + + Args: + joints (np.ndarray[..., 2]): Origin coordinate of joints. + mat (np.ndarray[3, 2]): The affine matrix. + + Returns: + matrix (np.ndarray[..., 2]): Result coordinate of joints. 
+ """ + joints = np.array(joints) + shape = joints.shape + joints = joints.reshape(-1, 2) + return np.dot(np.concatenate( + (joints, joints[:, 0:1] * 0 + 1), axis=1), + mat.T).reshape(shape) + + +class HRNetPostProcess(object): + def flip_back(self, output_flipped, matched_parts): + assert output_flipped.ndim == 4,\ + 'output_flipped should be [batch_size, num_joints, height, width]' + + output_flipped = output_flipped[:, :, :, ::-1] + + for pair in matched_parts: + tmp = output_flipped[:, pair[0], :, :].copy() + output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] + output_flipped[:, pair[1], :, :] = tmp + + return output_flipped + + def get_max_preds(self, heatmaps): + '''get predictions from score maps + + Args: + heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) + + Returns: + preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords + maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints + ''' + assert isinstance(heatmaps, + np.ndarray), 'heatmaps should be numpy.ndarray' + assert heatmaps.ndim == 4, 'batch_images should be 4-ndim' + + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + width = heatmaps.shape[3] + heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1)) + idx = np.argmax(heatmaps_reshaped, 2) + maxvals = np.amax(heatmaps_reshaped, 2) + + maxvals = maxvals.reshape((batch_size, num_joints, 1)) + idx = idx.reshape((batch_size, num_joints, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + + preds[:, :, 0] = (preds[:, :, 0]) % width + preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) + + pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) + pred_mask = pred_mask.astype(np.float32) + + preds *= pred_mask + + return preds, maxvals + + def get_final_preds(self, heatmaps, center, scale): + """the highest heatvalue location with a quarter offset in the + direction from the highest response to the second highest response. 
+ + Args: + heatmaps (numpy.ndarray): The predicted heatmaps + center (numpy.ndarray): The boxes center + scale (numpy.ndarray): The scale factor + + Returns: + preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords + maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints + """ + + coords, maxvals = self.get_max_preds(heatmaps) + + heatmap_height = heatmaps.shape[2] + heatmap_width = heatmaps.shape[3] + + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + hm = heatmaps[n][p] + px = int(math.floor(coords[n][p][0] + 0.5)) + py = int(math.floor(coords[n][p][1] + 0.5)) + if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1: + diff = np.array([ + hm[py][px + 1] - hm[py][px - 1], + hm[py + 1][px] - hm[py - 1][px] + ]) + coords[n][p] += np.sign(diff) * .25 + preds = coords.copy() + + # Transform back + for i in range(coords.shape[0]): + preds[i] = transform_preds(coords[i], center[i], scale[i], + [heatmap_width, heatmap_height]) + + return preds, maxvals + + def __call__(self, output, center, scale): + preds, maxvals = self.get_final_preds(output, center, scale) + return np.concatenate( + (preds, maxvals), axis=-1), np.mean( + maxvals, axis=1) + + +def transform_preds(coords, center, scale, output_size): + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale * 200, 0, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] diff --git a/deploy/python/keypoint_preprocess.py b/deploy/python/keypoint_preprocess.py new file mode 100644 index 000000000..345f2d7c2 --- /dev/null +++ b/deploy/python/keypoint_preprocess.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +class EvalAffine(object): + def __init__(self, size, stride=64): + super(EvalAffine, self).__init__() + self.size = size + self.stride = stride + + def __call__(self, image, im_info): + s = self.size + h, w, _ = image.shape + trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False) + image_resized = cv2.warpAffine(image, trans, size_resized) + return image_resized, im_info + + +def get_affine_mat_kernel(h, w, s, inv=False): + if w < h: + w_ = s + h_ = int(np.ceil((s / w * h) / 64.) * 64) + scale_w = w + scale_h = h_ / w_ * w + + else: + h_ = s + w_ = int(np.ceil((s / h * w) / 64.) 
* 64) + scale_h = h + scale_w = w_ / h_ * h + + center = np.array([np.round(w / 2.), np.round(h / 2.)]) + + size_resized = (w_, h_) + trans = get_affine_transform( + center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv) + + return trans, size_resized + + +def get_affine_transform(center, + input_size, + rot, + output_size, + shift=(0., 0.), + inv=False): + """Get the affine transform matrix, given the center/scale/rot/output_size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ]): Size of the destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: The transform matrix. + """ + assert len(center) == 2 + assert len(input_size) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + + scale_tmp = input_size + + shift = np.array(shift) + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = rotate_point([0., src_w * -0.5], rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def rotate_point(pt, angle_rad): + """Rotate a point by an angle. + + Args: + pt (list[float]): 2 dimensional point to be rotated + angle_rad (float): rotation angle by radian + + Returns: + list[float]: Rotated point. + """ + assert len(pt) == 2 + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + new_x = pt[0] * cs - pt[1] * sn + new_y = pt[0] * sn + pt[1] * cs + rotated_pt = [new_x, new_y] + + return rotated_pt + + +def _get_3rd_point(a, b): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): point(x,y) + b (np.ndarray): point(x,y) + + Returns: + np.ndarray: The 3rd point. + """ + assert len(a) == 2 + assert len(b) == 2 + direction = a - b + third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) + + return third_pt + + +class TopDownEvalAffine(object): + """apply affine transform to image and coords + + Args: + trainsize (list): [w, h], the standard size used to train + records(dict): the dict contained the image and coords + + Returns: + records (dict): contain the image and coords after tranformed + + """ + + def __init__(self, trainsize): + self.trainsize = trainsize + + def __call__(self, image, im_info): + rot = 0 + imshape = im_info['im_shape'][::-1] + center = im_info['center'] if 'center' in im_info else imshape / 2. 
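+        # like center above, scale falls back to the full image size when the
+        # record carries no detection box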
+ scale = im_info['scale'] if 'scale' in im_info else imshape + trans = get_affine_transform(center, scale, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + + return image, im_info diff --git a/deploy/python/keypoint_visualize.py b/deploy/python/keypoint_visualize.py new file mode 100644 index 000000000..8fc8e1176 --- /dev/null +++ b/deploy/python/keypoint_visualize.py @@ -0,0 +1,106 @@ +# coding: utf-8 +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import os +import numpy as np +import math + + +def map_coco_to_personlab(keypoints): + permute = [0, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3] + return keypoints[:, permute, :] + + +def draw_pose(imgfile, + results, + visual_thread=0.6, + save_name='pose.jpg', + returnimg=False): + try: + import matplotlib.pyplot as plt + import matplotlib + plt.switch_backend('agg') + except Exception as e: + logger.error('Matplotlib not found, please install matplotlib.' + 'for example: `pip install matplotlib`.') + raise e + + EDGES = [(0, 14), (0, 13), (0, 4), (0, 1), (14, 16), (13, 15), (4, 10), + (1, 7), (10, 11), (7, 8), (11, 12), (8, 9), (4, 5), (1, 2), (5, 6), + (2, 3)] + NUM_EDGES = len(EDGES) + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + cmap = matplotlib.cm.get_cmap('hsv') + plt.figure() + + img = cv2.imread(imgfile) if type(imgfile) == str else imgfile + skeletons, scores = results['keypoint'] + + if 'bbox' in results: + bboxs = results['bbox'] + for idx, rect in enumerate(bboxs): + xmin, ymin, xmax, ymax = rect + cv2.rectangle(img, (xmin, ymin), (xmax, ymax), + colors[idx % len(colors)], 2) + + canvas = img.copy() + for i in range(17): + rgba = np.array(cmap(1 - i / 17. - 1. 
/ 34)) + rgba[0:3] *= 255 + for j in range(len(skeletons)): + if skeletons[j][i, 2] < visual_thread: + continue + cv2.circle( + canvas, + tuple(skeletons[j][i, 0:2].astype('int32')), + 2, + colors[i], + thickness=-1) + + to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0) + fig = matplotlib.pyplot.gcf() + + stickwidth = 2 + + skeletons = map_coco_to_personlab(skeletons) + for i in range(NUM_EDGES): + for j in range(len(skeletons)): + edge = EDGES[i] + if skeletons[j][edge[0], 2] < visual_thread or skeletons[j][edge[ + 1], 2] < visual_thread: + continue + + cur_canvas = canvas.copy() + X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]] + Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]] + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), + (int(length / 2), stickwidth), + int(angle), 0, 360, 1) + cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + if returnimg: + return canvas + save_name = 'output/' + os.path.basename(imgfile)[:-4] + '_vis.jpg' + plt.imsave(save_name, canvas[:, :, ::-1]) + print("keypoint visualize image saved to: " + save_name) + plt.close() diff --git a/deploy/python/topdown_unite_utils.py b/deploy/python/topdown_unite_utils.py new file mode 100644 index 000000000..ab483109f --- /dev/null +++ b/deploy/python/topdown_unite_utils.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import argparse + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--det_model_dir", + type=str, + default=None, + help=("Directory include:'model.pdiparams', 'model.pdmodel', " + "'infer_cfg.yml', created by tools/export_model.py."), + required=True) + parser.add_argument( + "--keypoint_model_dir", + type=str, + default=None, + help=("Directory include:'model.pdiparams', 'model.pdmodel', " + "'infer_cfg.yml', created by tools/export_model.py."), + required=True) + parser.add_argument( + "--image_file", type=str, default=None, help="Path of image file.") + parser.add_argument( + "--image_dir", + type=str, + default=None, + help="Dir of image file, `image_file` has a higher priority.") + parser.add_argument( + "--video_file", + type=str, + default=None, + help="Path of video file, `video_file` or `camera_id` has a highest priority." 
+ ) + parser.add_argument( + "--camera_id", + type=int, + default=-1, + help="device id of camera to predict.") + parser.add_argument( + "--det_threshold", type=float, default=0.5, help="Threshold of score.") + parser.add_argument( + "--keypoint_threshold", + type=float, + default=0.5, + help="Threshold of score.") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory of output visualization files.") + parser.add_argument( + "--run_mode", + type=str, + default='fluid', + help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)") + parser.add_argument( + "--use_gpu", + type=ast.literal_eval, + default=False, + help="Whether to predict with GPU.") + parser.add_argument( + "--run_benchmark", + type=ast.literal_eval, + default=False, + help="Whether to predict a image_file repeatedly for benchmark") + parser.add_argument( + "--enable_mkldnn", + type=ast.literal_eval, + default=False, + help="Whether use mkldnn with CPU.") + parser.add_argument( + "--cpu_threads", type=int, default=1, help="Num of threads with CPU.") + parser.add_argument( + "--use_dynamic_shape", + type=ast.literal_eval, + default=False, + help="Dynamic_shape for TensorRT.") + parser.add_argument( + "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.") + parser.add_argument( + "--trt_max_shape", + type=int, + default=1280, + help="max_shape for TensorRT.") + parser.add_argument( + "--trt_opt_shape", + type=int, + default=640, + help="opt_shape for TensorRT.") + parser.add_argument( + "--trt_calib_mode", + type=bool, + default=False, + help="If the model is produced by TRT offline quantitative " + "calibration, trt_calib_mode need to set True.") + + return parser diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index a8fa02fc4..6b0926488 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -234,7 +234,7 @@ class OptimizerBuilder(): clip_norm=self.clip_grad_by_norm) else: grad_clip = None - if self.regularizer: + if self.regularizer and self.regularizer != 'None': reg_type = self.regularizer['type'] + 'Decay' reg_factor = self.regularizer['factor'] regularization = getattr(regularizer, reg_type)(reg_factor) -- GitLab