From 35e6657241e176863309830ffe6784f84a7365e0 Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Wed, 23 Jun 2021 16:40:21 +0800 Subject: [PATCH] add hrnet mpii dataset; (#3460) add dark deploy supported, mpii deploy supported; --- configs/keypoint/README.md | 14 +- .../keypoint/hrnet/dark_hrnet_w32_256x192.yml | 3 - .../keypoint/hrnet/dark_hrnet_w48_256x192.yml | 3 - configs/keypoint/hrnet/hrnet_w32_256x192.yml | 3 - .../keypoint/hrnet/hrnet_w32_256x256_mpii.yml | 130 ++++++++++++++ configs/keypoint/hrnet/hrnet_w32_384x288.yml | 3 - deploy/python/keypoint_det_unite_infer.py | 3 +- deploy/python/keypoint_infer.py | 9 +- deploy/python/keypoint_postprocess.py | 78 ++++++-- deploy/python/keypoint_visualize.py | 17 +- deploy/python/topdown_unite_utils.py | 5 + deploy/python/utils.py | 5 + ppdet/data/source/keypoint_coco.py | 18 +- ppdet/engine/trainer.py | 11 +- ppdet/metrics/keypoint_metrics.py | 167 +++++++++++++++++- 15 files changed, 425 insertions(+), 44 deletions(-) create mode 100644 configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml diff --git a/configs/keypoint/README.md b/configs/keypoint/README.md index 18f485d8f..9912ecc6a 100644 --- a/configs/keypoint/README.md +++ b/configs/keypoint/README.md @@ -13,7 +13,7 @@ #### Model Zoo - +COCO数据集 | 模型 | 输入尺寸 | 通道数 | AP(coco val) | 模型下载 | 配置文件 | | :---------------- | -------- | ------ | :----------: | :----------------------------------------------------------: | ----------------------------------------------------------- | | HigherHRNet | 512 | 32 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) | @@ -25,6 +25,12 @@ | HRNet+DarkPose | 384x288 | 32 | 78.3 | [dark_hrnet_w32_384x288.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/dark_hrnet_w32_384x288.pdparams) | [config](./hrnet/dark_hrnet_w32_384x288.yml) | 备注: Top-Down模型测试AP结果基于GroundTruth标注框 +MPII数据集 +| 模型 | 输入尺寸 | 通道数 | PCKh(Mean) | PCKh(Mean@0.1) | 模型下载 | 配置文件 | +| :---- | -------- | ------ | :--------: | :------------: | :----------------------------------------------------------: | -------------------------------------------- | +| HRNet | 256x256 | 32 | 90.6 | 38.5 | [hrnet_w32_256x256_mpii.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x256_mpii.pdparams) | [config](./hrnet/hrnet_w32_256x256_mpii.yml) | + + ## 快速开始 ### 1、环境安装 @@ -74,9 +80,9 @@ python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w #部署推理 #keypoint top-down/bottom-up 单独推理,该模式下top-down模型只支持单人截图预测。 -python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --use_gpu=True --threshold=0.5 -python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --use_gpu=True --threshold=0.5 +python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --device=gpu --threshold=0.5 +python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --device=gpu --threshold=0.5 #keypoint top-down模型 + detector 检测联合部署推理(联合推理只支持top-down方式) -python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4 --use_gpu=True +python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4 --device=gpu ``` diff --git a/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml b/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml index 5cc59fb2c..b6fa14450 100644 --- a/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml +++ b/configs/keypoint/hrnet/dark_hrnet_w32_256x192.yml @@ -118,9 +118,6 @@ EvalReader: sample_transforms: - TopDownAffine: trainsize: *trainsize - - ToHeatmapsTopDown_DARK: - hmsize: *hmsize - sigma: 2 batch_transforms: - NormalizeImage: mean: *global_mean diff --git a/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml b/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml index 901c0ad83..73ce4877a 100644 --- a/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml +++ b/configs/keypoint/hrnet/dark_hrnet_w48_256x192.yml @@ -118,9 +118,6 @@ EvalReader: sample_transforms: - TopDownAffine: trainsize: *trainsize - - ToHeatmapsTopDown_DARK: - hmsize: *hmsize - sigma: 2 batch_transforms: - NormalizeImage: mean: *global_mean diff --git a/configs/keypoint/hrnet/hrnet_w32_256x192.yml b/configs/keypoint/hrnet/hrnet_w32_256x192.yml index 50a838b9a..206fd0493 100644 --- a/configs/keypoint/hrnet/hrnet_w32_256x192.yml +++ b/configs/keypoint/hrnet/hrnet_w32_256x192.yml @@ -118,9 +118,6 @@ EvalReader: sample_transforms: - TopDownAffine: trainsize: *trainsize - - ToHeatmapsTopDown: - hmsize: *hmsize - sigma: 2 batch_transforms: - NormalizeImage: mean: *global_mean diff --git a/configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml b/configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml new file mode 100644 index 000000000..c29190e08 --- /dev/null +++ b/configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml @@ -0,0 +1,130 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 10 +weights: output/hrnet_w32_256x256_mpii/model_final +epoch: 210 +num_joints: &num_joints 16 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownMPIIEval +num_classes: 1 +train_height: &train_height 256 +train_width: &train_width 256 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [64, 64] +flip_perm: &flip_perm [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]] + +#####model +architecture: TopDownHRNet +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/Trunc_HRNet_W32_C_pretrained.pdparams + +TopDownHRNet: + backbone: HRNet + post_process: HRNetPostProcess + flip_perm: *flip_perm + num_joints: *num_joints + width: &width 32 + loss: KeyPointMSELoss + +HRNet: + width: *width + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointMSELoss: + use_target_weight: true + + +#####optimizer +LearningRate: + base_lr: 0.0005 + schedulers: + - !PiecewiseDecay + milestones: [170, 200] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.001 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + + +#####data +TrainDataset: + !KeypointTopDownMPIIDataset + image_dir: images + anno_path: annotations/mpii_train.json + dataset_dir: dataset/mpii + num_joints: *num_joints + + +EvalDataset: + !KeypointTopDownMPIIDataset + image_dir: images + anno_path: annotations/mpii_val.json + dataset_dir: dataset/mpii + num_joints: *num_joints + + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 4 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomFlipHalfBodyTransform: + scale: 0.5 + rot: 40 + num_joints_half_body: 8 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [7, 8, 9, 10, 11, 12, 13, 14, 15] + flip_pairs: *flip_perm + - TopDownAffine: + trainsize: *trainsize + - ToHeatmapsTopDown: + hmsize: *hmsize + sigma: 2 + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 64 + shuffle: true + drop_last: false + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 16 + +TestReader: + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 diff --git a/configs/keypoint/hrnet/hrnet_w32_384x288.yml b/configs/keypoint/hrnet/hrnet_w32_384x288.yml index 72dc8120f..eb0e87f81 100644 --- a/configs/keypoint/hrnet/hrnet_w32_384x288.yml +++ b/configs/keypoint/hrnet/hrnet_w32_384x288.yml @@ -119,9 +119,6 @@ EvalReader: sample_transforms: - TopDownAffine: trainsize: *trainsize - - ToHeatmapsTopDown: - hmsize: *hmsize - sigma: 2 batch_transforms: - NormalizeImage: mean: *global_mean diff --git a/deploy/python/keypoint_det_unite_infer.py b/deploy/python/keypoint_det_unite_infer.py index 8dcd92c81..a9b0ea69b 100644 --- a/deploy/python/keypoint_det_unite_infer.py +++ b/deploy/python/keypoint_det_unite_infer.py @@ -178,7 +178,8 @@ def main(): trt_opt_shape=FLAGS.trt_opt_shape, trt_calib_mode=FLAGS.trt_calib_mode, cpu_threads=FLAGS.cpu_threads, - enable_mkldnn=FLAGS.enable_mkldnn) + enable_mkldnn=FLAGS.enable_mkldnn, + use_dark=FLAGS.use_dark) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py index 74bd84a16..981ccd080 100644 --- a/deploy/python/keypoint_infer.py +++ b/deploy/python/keypoint_infer.py @@ -63,7 +63,8 @@ class KeyPoint_Detector(object): trt_opt_shape=640, trt_calib_mode=False, cpu_threads=1, - enable_mkldnn=False): + enable_mkldnn=False, + use_dark=True): self.pred_config = pred_config self.predictor, self.config = load_predictor( model_dir, @@ -79,6 +80,7 @@ class KeyPoint_Detector(object): enable_mkldnn=enable_mkldnn) self.det_times = Timer() self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + self.use_dark = use_dark def preprocess(self, im): preprocess_ops = [] @@ -109,7 +111,7 @@ class KeyPoint_Detector(object): imshape = inputs['im_shape'][:, ::-1] center = np.round(imshape / 2.) scale = imshape / 200. - keypoint_postprocess = HRNetPostProcess() + keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark) results['keypoint'] = keypoint_postprocess(np_boxes, center, scale) return results else: @@ -390,7 +392,8 @@ def main(): trt_opt_shape=FLAGS.trt_opt_shape, trt_calib_mode=FLAGS.trt_calib_mode, cpu_threads=FLAGS.cpu_threads, - enable_mkldnn=FLAGS.enable_mkldnn) + enable_mkldnn=FLAGS.enable_mkldnn, + use_dark=FLAGS.use_dark) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/keypoint_postprocess.py b/deploy/python/keypoint_postprocess.py index a26244d41..fe51d1ab2 100644 --- a/deploy/python/keypoint_postprocess.py +++ b/deploy/python/keypoint_postprocess.py @@ -14,6 +14,7 @@ from scipy.optimize import linear_sum_assignment from collections import abc, defaultdict +import cv2 import numpy as np import math import paddle @@ -193,6 +194,9 @@ def warp_affine_joints(joints, mat): class HRNetPostProcess(object): + def __init__(self, use_dark=True): + self.use_dark = use_dark + def flip_back(self, output_flipped, matched_parts): assert output_flipped.ndim == 4,\ 'output_flipped should be [batch_size, num_joints, height, width]' @@ -242,7 +246,54 @@ class HRNetPostProcess(object): return preds, maxvals - def get_final_preds(self, heatmaps, center, scale): + def gaussian_blur(self, heatmap, kernel): + border = (kernel - 1) // 2 + batch_size = heatmap.shape[0] + num_joints = heatmap.shape[1] + height = heatmap.shape[2] + width = heatmap.shape[3] + for i in range(batch_size): + for j in range(num_joints): + origin_max = np.max(heatmap[i, j]) + dr = np.zeros((height + 2 * border, width + 2 * border)) + dr[border:-border, border:-border] = heatmap[i, j].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmap[i, j] = dr[border:-border, border:-border].copy() + heatmap[i, j] *= origin_max / np.max(heatmap[i, j]) + return heatmap + + def dark_parse(self, hm, coord): + heatmap_height = hm.shape[0] + heatmap_width = hm.shape[1] + px = int(coord[0]) + py = int(coord[1]) + if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2: + dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1]) + dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px]) + dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2]) + dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \ + + hm[py-1][px-1]) + dyy = 0.25 * ( + hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px]) + derivative = np.matrix([[dx], [dy]]) + hessian = np.matrix([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = hessian.I + offset = -hessianinv * derivative + offset = np.squeeze(np.array(offset.T), axis=0) + coord += offset + return coord + + def dark_postprocess(self, hm, coords, kernelsize): + hm = self.gaussian_blur(hm, kernelsize) + hm = np.maximum(hm, 1e-10) + hm = np.log(hm) + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + coords[n, p] = self.dark_parse(hm[n][p], coords[n][p]) + return coords + + def get_final_preds(self, heatmaps, center, scale, kernelsize=3): """the highest heatvalue location with a quarter offset in the direction from the highest response to the second highest response. @@ -261,17 +312,20 @@ class HRNetPostProcess(object): heatmap_height = heatmaps.shape[2] heatmap_width = heatmaps.shape[3] - for n in range(coords.shape[0]): - for p in range(coords.shape[1]): - hm = heatmaps[n][p] - px = int(math.floor(coords[n][p][0] + 0.5)) - py = int(math.floor(coords[n][p][1] + 0.5)) - if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1: - diff = np.array([ - hm[py][px + 1] - hm[py][px - 1], - hm[py + 1][px] - hm[py - 1][px] - ]) - coords[n][p] += np.sign(diff) * .25 + if self.use_dark: + coords = self.dark_postprocess(heatmaps, coords, kernelsize) + else: + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + hm = heatmaps[n][p] + px = int(math.floor(coords[n][p][0] + 0.5)) + py = int(math.floor(coords[n][p][1] + 0.5)) + if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1: + diff = np.array([ + hm[py][px + 1] - hm[py][px - 1], + hm[py + 1][px] - hm[py - 1][px] + ]) + coords[n][p] += np.sign(diff) * .25 preds = coords.copy() # Transform back diff --git a/deploy/python/keypoint_visualize.py b/deploy/python/keypoint_visualize.py index b379bd702..828207621 100644 --- a/deploy/python/keypoint_visualize.py +++ b/deploy/python/keypoint_visualize.py @@ -34,9 +34,16 @@ def draw_pose(imgfile, 'for example: `pip install matplotlib`.') raise e - EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8), - (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), - (14, 16), (11, 12)] + skeletons, scores = results['keypoint'] + kpt_nums = len(skeletons[0]) + if kpt_nums == 17: #plot coco keypoint + EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8), + (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), + (13, 15), (14, 16), (11, 12)] + else: #plot mpii keypoint + EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8), + (8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12), + (8, 13)] NUM_EDGES = len(EDGES) colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ @@ -46,7 +53,7 @@ def draw_pose(imgfile, plt.figure() img = cv2.imread(imgfile) if type(imgfile) == str else imgfile - skeletons, scores = results['keypoint'] + color_set = results['colors'] if 'colors' in results else None if 'bbox' in results: @@ -58,7 +65,7 @@ def draw_pose(imgfile, cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1) canvas = img.copy() - for i in range(17): + for i in range(kpt_nums): for j in range(len(skeletons)): if skeletons[j][i, 2] < visual_thread: continue diff --git a/deploy/python/topdown_unite_utils.py b/deploy/python/topdown_unite_utils.py index ff34c5e8f..6f7b63df6 100644 --- a/deploy/python/topdown_unite_utils.py +++ b/deploy/python/topdown_unite_utils.py @@ -103,5 +103,10 @@ def argsparser(): default=False, help="If the model is produced by TRT offline quantitative " "calibration, trt_calib_mode need to set True.") + parser.add_argument( + '--use_dark', + type=bool, + default=True, + help='whether to use darkpose to get better keypoint position predict ') return parser diff --git a/deploy/python/utils.py b/deploy/python/utils.py index 75e54b07c..35ad43714 100644 --- a/deploy/python/utils.py +++ b/deploy/python/utils.py @@ -108,6 +108,11 @@ def argsparser(): '--save_results', action='store_true', help='Save tracking results (txt).') + parser.add_argument( + '--use_dark', + type=bool, + default=True, + help='whether to use darkpose to get better keypoint position predict ') return parser diff --git a/ppdet/data/source/keypoint_coco.py b/ppdet/data/source/keypoint_coco.py index 471a7de9f..5b7b99ee9 100644 --- a/ppdet/data/source/keypoint_coco.py +++ b/ppdet/data/source/keypoint_coco.py @@ -25,7 +25,8 @@ from ppdet.core.workspace import register, serializable @serializable class KeypointBottomUpBaseDataset(DetDataset): - """Base class for bottom-up datasets. + """Base class for bottom-up datasets. Adapted from + https://github.com/open-mmlab/mmpose All datasets should subclass it. All subclasses should overwrite: @@ -86,7 +87,8 @@ class KeypointBottomUpBaseDataset(DetDataset): @register @serializable class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): - """COCO dataset for bottom-up pose estimation. + """COCO dataset for bottom-up pose estimation. Adapted from + https://github.com/open-mmlab/mmpose The dataset loads raw features and apply specified transforms to return a dict containing the image tensors and other information. @@ -253,7 +255,8 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): @register @serializable class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset): - """CrowdPose dataset for bottom-up pose estimation. + """CrowdPose dataset for bottom-up pose estimation. Adapted from + https://github.com/open-mmlab/mmpose The dataset loads raw features and apply specified transforms to return a dict containing the image tensors and other information. @@ -374,7 +377,9 @@ class KeypointTopDownBaseDataset(DetDataset): @register @serializable class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): - """COCO dataset for top-down pose estimation. + """COCO dataset for top-down pose estimation. Adapted from + https://github.com/leoxiaobin/deep-high-resolution-net.pytorch + Copyright (c) Microsoft, under the MIT License. The dataset loads raw features and apply specified transforms to return a dict containing the image tensors and other information. @@ -567,7 +572,9 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): @register @serializable class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): - """MPII dataset for topdown pose estimation. + """MPII dataset for topdown pose estimation. Adapted from + https://github.com/leoxiaobin/deep-high-resolution-net.pytorch + Copyright (c) Microsoft, under the MIT License. The dataset loads raw features and apply specified transforms to return a dict containing the image tensors and other information. @@ -653,4 +660,5 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): 'joints': joints, 'joints_vis': joints_vis }) + print("number length: {}".format(len(gt_db))) self.db = gt_db diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 49b438924..424e94742 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -35,7 +35,7 @@ from ppdet.optimizer import ModelEMA from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result -from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval +from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval from ppdet.metrics import RBoxMetric from ppdet.data.source.category import get_categories import ppdet.utils.stats as stats @@ -234,6 +234,15 @@ class Trainer(object): len(eval_dataset), self.cfg.num_joints, self.cfg.save_dir) ] + elif self.cfg.metric == 'KeyPointTopDownMPIIEval': + eval_dataset = self.cfg['EvalDataset'] + eval_dataset.check_or_download_dataset() + anno_file = eval_dataset.get_anno() + self._metrics = [ + KeyPointTopDownMPIIEval(anno_file, + len(eval_dataset), self.cfg.num_joints, + self.cfg.save_dir) + ] else: logger.warn("Metric not support for metric type {}".format( self.cfg.metric)) diff --git a/ppdet/metrics/keypoint_metrics.py b/ppdet/metrics/keypoint_metrics.py index 9e956bb55..a2bcccaf6 100644 --- a/ppdet/metrics/keypoint_metrics.py +++ b/ppdet/metrics/keypoint_metrics.py @@ -21,11 +21,18 @@ import numpy as np from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from ..modeling.keypoint_utils import oks_nms +from scipy.io import loadmat, savemat -__all__ = ['KeyPointTopDownCOCOEval'] +__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval'] class KeyPointTopDownCOCOEval(object): + ''' + Adapted from + https://github.com/leoxiaobin/deep-high-resolution-net.pytorch + Copyright (c) Microsoft, under the MIT License. + ''' + def __init__(self, anno_file, num_samples, @@ -200,3 +207,161 @@ class KeyPointTopDownCOCOEval(object): def get_results(self): return self.eval_results + + +class KeyPointTopDownMPIIEval(object): + def __init__(self, + anno_file, + num_samples, + num_joints, + output_eval, + oks_thre=0.9): + super(KeyPointTopDownMPIIEval, self).__init__() + self.ann_file = anno_file + self.reset() + + def reset(self): + self.results = [] + self.eval_results = {} + self.idx = 0 + + def update(self, inputs, outputs): + kpts, _ = outputs['keypoint'][0] + + num_images = inputs['image'].shape[0] + results = {} + results['preds'] = kpts[:, :, 0:3] + results['boxes'] = np.zeros((num_images, 6)) + results['boxes'][:, 0:2] = inputs['center'].numpy()[:, 0:2] + results['boxes'][:, 2:4] = inputs['scale'].numpy()[:, 0:2] + results['boxes'][:, 4] = np.prod(inputs['scale'].numpy() * 200, 1) + results['boxes'][:, 5] = np.squeeze(inputs['score'].numpy()) + results['image_path'] = inputs['image_file'] + + self.results.append(results) + + def accumulate(self): + self.eval_results = self.evaluate(self.results) + + def log(self): + for item, value in self.eval_results.items(): + print("{} : {}".format(item, value)) + + def get_results(self): + return self.eval_results + + def evaluate(self, outputs, savepath=None): + """Evaluate PCKh for MPII dataset. Adapted from + https://github.com/leoxiaobin/deep-high-resolution-net.pytorch + Copyright (c) Microsoft, under the MIT License. + + Args: + outputs(list(preds, boxes)): + + * preds (np.ndarray[N,K,3]): The first two dimensions are + coordinates, score is the third dimension of the array. + * boxes (np.ndarray[N,6]): [center[0], center[1], scale[0] + , scale[1],area, score] + + Returns: + dict: PCKh for each joint + """ + + kpts = [] + for output in outputs: + preds = output['preds'] + batch_size = preds.shape[0] + for i in range(batch_size): + kpts.append({'keypoints': preds[i]}) + + preds = np.stack([kpt['keypoints'] for kpt in kpts]) + + # convert 0-based index to 1-based index, + # and get the first two dimensions. + preds = preds[..., :2] + 1.0 + + if savepath is not None: + pred_file = os.path.join(savepath, 'pred.mat') + savemat(pred_file, mdict={'preds': preds}) + + SC_BIAS = 0.6 + threshold = 0.5 + + gt_file = os.path.join( + os.path.dirname(self.ann_file), 'mpii_gt_val.mat') + gt_dict = loadmat(gt_file) + dataset_joints = gt_dict['dataset_joints'] + jnt_missing = gt_dict['jnt_missing'] + pos_gt_src = gt_dict['pos_gt_src'] + headboxes_src = gt_dict['headboxes_src'] + + pos_pred_src = np.transpose(preds, [1, 2, 0]) + + head = np.where(dataset_joints == 'head')[1][0] + lsho = np.where(dataset_joints == 'lsho')[1][0] + lelb = np.where(dataset_joints == 'lelb')[1][0] + lwri = np.where(dataset_joints == 'lwri')[1][0] + lhip = np.where(dataset_joints == 'lhip')[1][0] + lkne = np.where(dataset_joints == 'lkne')[1][0] + lank = np.where(dataset_joints == 'lank')[1][0] + + rsho = np.where(dataset_joints == 'rsho')[1][0] + relb = np.where(dataset_joints == 'relb')[1][0] + rwri = np.where(dataset_joints == 'rwri')[1][0] + rkne = np.where(dataset_joints == 'rkne')[1][0] + rank = np.where(dataset_joints == 'rank')[1][0] + rhip = np.where(dataset_joints == 'rhip')[1][0] + + jnt_visible = 1 - jnt_missing + uv_error = pos_pred_src - pos_gt_src + uv_err = np.linalg.norm(uv_error, axis=1) + headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :] + headsizes = np.linalg.norm(headsizes, axis=0) + headsizes *= SC_BIAS + scale = headsizes * np.ones((len(uv_err), 1), dtype=np.float32) + scaled_uv_err = uv_err / scale + scaled_uv_err = scaled_uv_err * jnt_visible + jnt_count = np.sum(jnt_visible, axis=1) + less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible + PCKh = 100. * np.sum(less_than_threshold, axis=1) / jnt_count + + # save + rng = np.arange(0, 0.5 + 0.01, 0.01) + pckAll = np.zeros((len(rng), 16), dtype=np.float32) + + for r, threshold in enumerate(rng): + less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible + pckAll[r, :] = 100. * np.sum(less_than_threshold, + axis=1) / jnt_count + + PCKh = np.ma.array(PCKh, mask=False) + PCKh.mask[6:8] = True + + jnt_count = np.ma.array(jnt_count, mask=False) + jnt_count.mask[6:8] = True + jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64) + + name_value = [ #noqa + ('Head', PCKh[head]), + ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])), + ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])), + ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])), + ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])), + ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])), + ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])), + ('PCKh', np.sum(PCKh * jnt_ratio)), + ('PCKh@0.1', np.sum(pckAll[11, :] * jnt_ratio)) + ] + name_value = OrderedDict(name_value) + + return name_value + + def _sort_and_unique_bboxes(self, kpts, key='bbox_id'): + """sort kpts and remove the repeated ones.""" + kpts = sorted(kpts, key=lambda x: x[key]) + num = len(kpts) + for i in range(num - 1, 0, -1): + if kpts[i][key] == kpts[i - 1][key]: + del kpts[i] + + return kpts -- GitLab