From a694be1e948dc3437f525e3aedad5bde41183f97 Mon Sep 17 00:00:00 2001 From: Yang Nie Date: Mon, 22 May 2023 13:53:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=20+=20No.163=E3=80=91?= =?UTF-8?q?=E5=9F=BA=E4=BA=8EPaddleDetection=20PP-TinyPose=EF=BC=8C?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=89=8B=E5=8A=BF=E5=85=B3=E9=94=AE=E7=82=B9?= =?UTF-8?q?=E6=A3=80=E6=B5=8B=E6=A8=A1=E5=9E=8B=20(#8066)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support COCO Whole Bady Hand * update transforms * disable `AugmentationbyInformantionDropping` * fix infer bug * fix getImgIds --- .../tiny_pose/tinypose_256x256_hand.yml | 145 +++++++++++++++ ppdet/data/source/category.py | 6 +- ppdet/data/source/keypoint_coco.py | 116 ++++++++++++ ppdet/data/transform/keypoint_operators.py | 129 ++++++++++++++ ppdet/engine/trainer.py | 17 +- ppdet/metrics/keypoint_metrics.py | 165 +++++++++++++++++- ppdet/modeling/keypoint_utils.py | 148 ++++++++++++++++ 7 files changed, 720 insertions(+), 6 deletions(-) create mode 100644 configs/keypoint/tiny_pose/tinypose_256x256_hand.yml diff --git a/configs/keypoint/tiny_pose/tinypose_256x256_hand.yml b/configs/keypoint/tiny_pose/tinypose_256x256_hand.yml new file mode 100644 index 000000000..db691f06b --- /dev/null +++ b/configs/keypoint/tiny_pose/tinypose_256x256_hand.yml @@ -0,0 +1,145 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 10 +weights: output/tinypose_256x256_hand/model_final +epoch: 210 +num_joints: &num_joints 21 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownCOCOWholeBadyHandEval +num_classes: 1 +train_height: &train_height 256 +train_width: &train_width 256 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [64, 64] +flip_perm: &flip_perm [] + + +#####model +architecture: TopDownHRNet + +TopDownHRNet: + backbone: LiteHRNet + post_process: HRNetPostProcess + flip_perm: *flip_perm + num_joints: *num_joints + width: &width 40 + loss: KeyPointMSELoss + use_dark: true + +LiteHRNet: + network_type: wider_naive + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +KeyPointMSELoss: + use_target_weight: true + loss_scale: 1.0 + + +#####optimizer +LearningRate: + base_lr: 0.002 + schedulers: + - !PiecewiseDecay + milestones: [170, 200] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.001 + steps: 500 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + + +#####data +TrainDataset: + !KeypointTopDownCocoWholeBodyHandDataset + image_dir: train2017 + anno_path: annotations/coco_wholebody_train_v1.0.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + +EvalDataset: + !KeypointTopDownCocoWholeBodyHandDataset + image_dir: val2017 + anno_path: annotations/coco_wholebody_val_v1.0.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 2 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - TopDownRandomShiftBboxCenter: + shift_prob: 0.3 + shift_factor: 0.16 + - TopDownRandomFlip: + flip_prob: 0.5 + flip_perm: *flip_perm + - TopDownGetRandomScaleRotation: + rot_prob: 0.6 + rot_factor: 90 + scale_factor: 0.3 + # - AugmentationbyInformantionDropping: + # prob_cutout: 0.5 + # offset_factor: 0.05 + # num_patch: 1 + # trainsize: *trainsize + - TopDownAffine: + trainsize: *trainsize + 
use_udp: true + - ToHeatmapsTopDown_DARK: + hmsize: *hmsize + sigma: 2 + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 128 + shuffle: true + drop_last: false + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + use_udp: true + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 128 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: false diff --git a/ppdet/data/source/category.py b/ppdet/data/source/category.py index 4da25a2d2..8ed1f9e04 100644 --- a/ppdet/data/source/category.py +++ b/ppdet/data/source/category.py @@ -114,8 +114,10 @@ def get_categories(metric_type, anno_file=None, arch=None): elif metric_type.lower() == 'widerface': return _widerface_category() - elif metric_type.lower() == 'keypointtopdowncocoeval' or metric_type.lower( - ) == 'keypointtopdownmpiieval': + elif metric_type.lower() in [ + 'keypointtopdowncocoeval', 'keypointtopdownmpiieval', + 'keypointtopdowncocowholebadyhandeval' + ]: return (None, {'id': 'keypoint'}) elif metric_type.lower() == 'pose3deval': diff --git a/ppdet/data/source/keypoint_coco.py b/ppdet/data/source/keypoint_coco.py index 6e072dc6e..86d83439b 100644 --- a/ppdet/data/source/keypoint_coco.py +++ b/ppdet/data/source/keypoint_coco.py @@ -635,6 +635,122 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): return kpt_db +@register +@serializable +class KeypointTopDownCocoWholeBodyHandDataset(KeypointTopDownBaseDataset): + """CocoWholeBody dataset for top-down hand pose estimation. + + The dataset loads raw features and apply specified transforms + to return a dict containing the image tensors and other information. + + COCO-WholeBody Hand keypoint indexes: + + 0: 'wrist', + 1: 'thumb1', + 2: 'thumb2', + 3: 'thumb3', + 4: 'thumb4', + 5: 'forefinger1', + 6: 'forefinger2', + 7: 'forefinger3', + 8: 'forefinger4', + 9: 'middle_finger1', + 10: 'middle_finger2', + 11: 'middle_finger3', + 12: 'middle_finger4', + 13: 'ring_finger1', + 14: 'ring_finger2', + 15: 'ring_finger3', + 16: 'ring_finger4', + 17: 'pinky_finger1', + 18: 'pinky_finger2', + 19: 'pinky_finger3', + 20: 'pinky_finger4' + + Args: + dataset_dir (str): Root path to the dataset. + image_dir (str): Path to a directory where images are held. + anno_path (str): Relative path to the annotation file. + num_joints (int): Keypoint numbers + trainsize (list):[w, h] Image target size + transform (composed(operators)): A sequence of data transforms. + pixel_std (int): The pixel std of the scale + Default: 200. 
+ """ + + def __init__(self, + dataset_dir, + image_dir, + anno_path, + num_joints, + trainsize, + transform=[], + pixel_std=200): + super().__init__(dataset_dir, image_dir, anno_path, num_joints, + transform) + + self.trainsize = trainsize + self.pixel_std = pixel_std + self.dataset_name = 'coco_wholebady_hand' + + def _box2cs(self, box): + x, y, w, h = box[:4] + center = np.zeros((2), dtype=np.float32) + center[0] = x + w * 0.5 + center[1] = y + h * 0.5 + aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1] + + if w > aspect_ratio * h: + h = w * 1.0 / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + scale = np.array( + [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std], + dtype=np.float32) + if center[0] != -1: + scale = scale * 1.25 + + return center, scale + + def parse_dataset(self): + gt_db = [] + num_joints = self.ann_info['num_joints'] + coco = COCO(self.get_anno()) + img_ids = list(coco.imgs.keys()) + for img_id in img_ids: + im_ann = coco.loadImgs(img_id)[0] + image_file = os.path.join(self.img_prefix, im_ann['file_name']) + im_id = int(im_ann["id"]) + + ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False) + objs = coco.loadAnns(ann_ids) + + for obj in objs: + for type in ['left', 'right']: + if (obj[f'{type}hand_valid'] and + max(obj[f'{type}hand_kpts']) > 0): + + joints = np.zeros((num_joints, 3), dtype=np.float32) + joints_vis = np.zeros((num_joints, 3), dtype=np.float32) + + keypoints = np.array(obj[f'{type}hand_kpts']) + keypoints = keypoints.reshape(-1, 3) + joints[:, :2] = keypoints[:, :2] + joints_vis[:, :2] = np.minimum(1, keypoints[:, 2:3]) + + center, scale = self._box2cs(obj[f'{type}hand_box'][:4]) + gt_db.append({ + 'image_file': image_file, + 'center': center, + 'scale': scale, + 'gt_joints': joints, + 'joints_vis': joints_vis, + 'im_id': im_id, + }) + + self.db = gt_db + + @register @serializable class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): diff --git a/ppdet/data/transform/keypoint_operators.py b/ppdet/data/transform/keypoint_operators.py index fea23d696..d29aa2397 100644 --- a/ppdet/data/transform/keypoint_operators.py +++ b/ppdet/data/transform/keypoint_operators.py @@ -38,6 +38,7 @@ registered_ops = [] __all__ = [ 'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps', 'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform', + 'TopDownRandomFlip', 'TopDownRandomShiftBboxCenter', 'TopDownGetRandomScaleRotation', 'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK', 'ToHeatmapsTopDown_UDP', 'TopDownEvalAffine', 'AugmentationbyInformantionDropping', 'SinglePoseAffine', 'NoiseJitter', @@ -687,6 +688,134 @@ class AugmentationbyInformantionDropping(object): return records +@register_keypointop +class TopDownRandomFlip(object): + """Data augmentation with random image flip. + + Args: + flip_perm: (list[tuple]): Pairs of keypoints which are mirrored + (for example, left ear and right ear). + flip_prob (float): Probability of flip. 
+ """ + + def __init__(self, flip_perm=[], flip_prob=0.5): + self.flip_perm = flip_perm + self.flip_prob = flip_prob + + def flip_joints(self, joints_3d, joints_3d_visible, img_width, flip_pairs): + assert len(joints_3d) == len(joints_3d_visible) + assert img_width > 0 + + joints_3d_flipped = joints_3d.copy() + joints_3d_visible_flipped = joints_3d_visible.copy() + + # Swap left-right parts + for left, right in flip_pairs: + joints_3d_flipped[left, :] = joints_3d[right, :] + joints_3d_flipped[right, :] = joints_3d[left, :] + + joints_3d_visible_flipped[left, :] = joints_3d_visible[right, :] + joints_3d_visible_flipped[right, :] = joints_3d_visible[left, :] + + # Flip horizontally + joints_3d_flipped[:, 0] = img_width - 1 - joints_3d_flipped[:, 0] + joints_3d_flipped = joints_3d_flipped * (joints_3d_visible_flipped > 0) + + return joints_3d_flipped, joints_3d_visible_flipped + + def __call__(self, results): + """Perform data augmentation with random image flip.""" + if np.random.rand() <= self.flip_prob: + return results + + img = results['image'] + joints_3d = results['gt_joints'] + joints_3d_visible = results['joints_vis'] + center = results['center'] + + # A flag indicating whether the image is flipped, + # which can be used by child class. + if not isinstance(img, list): + img = img[:, ::-1, :] + else: + img = [i[:, ::-1, :] for i in img] + if not isinstance(img, list): + joints_3d, joints_3d_visible = self.flip_joints( + joints_3d, joints_3d_visible, img.shape[1], + self.flip_perm) + center[0] = img.shape[1] - center[0] - 1 + else: + joints_3d, joints_3d_visible = self.flip_joints( + joints_3d, joints_3d_visible, img[0].shape[1], + self.flip_perm) + center[0] = img[0].shape[1] - center[0] - 1 + + results['image'] = img + results['gt_joints'] = joints_3d + results['joints_vis'] = joints_3d_visible + results['center'] = center + + return results + + +@register_keypointop +class TopDownRandomShiftBboxCenter(object): + """Random shift the bbox center. + + Args: + shift_factor (float): The factor to control the shift range, which is + scale*pixel_std*scale_factor. Default: 0.16 + shift_prob (float): Probability of applying random shift. Default: 0.3 + """ + + def __init__(self, shift_factor=0.16, shift_prob=0.3): + self.shift_factor = shift_factor + self.shift_prob = shift_prob + + def __call__(self, results): + center = results['center'] + scale = results['scale'] + if np.random.rand() < self.shift_prob: + center += np.random.uniform( + -1, 1, 2) * self.shift_factor * scale * 200.0 + + results['center'] = center + return results + +@register_keypointop +class TopDownGetRandomScaleRotation(object): + """Data augmentation with random scaling & rotating. + + Args: + rot_factor (int): Rotating to ``[-2*rot_factor, 2*rot_factor]``. + scale_factor (float): Scaling to ``[1-scale_factor, 1+scale_factor]``. + rot_prob (float): Probability of random rotation. 
+ """ + + def __init__(self, rot_factor=40, scale_factor=0.5, rot_prob=0.6): + self.rot_factor = rot_factor + self.scale_factor = scale_factor + self.rot_prob = rot_prob + + def __call__(self, results): + """Perform data augmentation with random scaling & rotating.""" + s = results['scale'] + + sf = self.scale_factor + rf = self.rot_factor + + s_factor = np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) + s = s * s_factor + + r_factor = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) + r = r_factor if np.random.rand() <= self.rot_prob else 0 + + results['scale'] = s + results['rotate'] = r + + return results + + @register_keypointop class TopDownAffine(object): """apply affine transform to image and coords diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index f022793b6..bfd92fd62 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -38,8 +38,8 @@ from ppdet.optimizer import ModelEMA from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result -from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval, Pose3DEval -from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric +from ppdet.metrics import get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownCOCOWholeBadyHandEval, KeyPointTopDownMPIIEval, Pose3DEval +from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, RBoxMetric, JDEDetMetric, SNIPERCOCOMetric from ppdet.data.source.sniper_coco import SniperCOCODataSet from ppdet.data.source.category import get_categories import ppdet.utils.stats as stats @@ -348,6 +348,19 @@ class Trainer(object): self.cfg.save_dir, save_prediction_only=save_prediction_only) ] + elif self.cfg.metric == 'KeyPointTopDownCOCOWholeBadyHandEval': + eval_dataset = self.cfg['EvalDataset'] + eval_dataset.check_or_download_dataset() + anno_file = eval_dataset.get_anno() + save_prediction_only = self.cfg.get('save_prediction_only', False) + self._metrics = [ + KeyPointTopDownCOCOWholeBadyHandEval( + anno_file, + len(eval_dataset), + self.cfg.num_joints, + self.cfg.save_dir, + save_prediction_only=save_prediction_only) + ] elif self.cfg.metric == 'KeyPointTopDownMPIIEval': eval_dataset = self.cfg['EvalDataset'] eval_dataset.check_or_download_dataset() diff --git a/ppdet/metrics/keypoint_metrics.py b/ppdet/metrics/keypoint_metrics.py index cbd52d02d..26e9ecb5e 100644 --- a/ppdet/metrics/keypoint_metrics.py +++ b/ppdet/metrics/keypoint_metrics.py @@ -19,12 +19,15 @@ import numpy as np import paddle from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval -from ..modeling.keypoint_utils import oks_nms +from ..modeling.keypoint_utils import oks_nms, keypoint_pck_accuracy, keypoint_auc, keypoint_epe from scipy.io import loadmat, savemat from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) -__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval'] +__all__ = [ + 'KeyPointTopDownCOCOEval', 'KeyPointTopDownCOCOWholeBadyHandEval', + 'KeyPointTopDownMPIIEval' +] class KeyPointTopDownCOCOEval(object): @@ -226,6 +229,164 @@ class KeyPointTopDownCOCOEval(object): return self.eval_results +class KeyPointTopDownCOCOWholeBadyHandEval(object): + def __init__(self, + anno_file, + num_samples, + num_joints, + output_eval, + save_prediction_only=False): + super(KeyPointTopDownCOCOWholeBadyHandEval, self).__init__() + self.coco = 
COCO(anno_file) + self.num_samples = num_samples + self.num_joints = num_joints + self.output_eval = output_eval + self.res_file = os.path.join(output_eval, "keypoints_results.json") + self.save_prediction_only = save_prediction_only + self.parse_dataset() + self.reset() + + def parse_dataset(self): + gt_db = [] + num_joints = self.num_joints + coco = self.coco + img_ids = coco.getImgIds() + for img_id in img_ids: + ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False) + objs = coco.loadAnns(ann_ids) + + for obj in objs: + for type in ['left', 'right']: + if (obj[f'{type}hand_valid'] and + max(obj[f'{type}hand_kpts']) > 0): + + joints = np.zeros((num_joints, 3), dtype=np.float32) + joints_vis = np.zeros((num_joints, 3), dtype=np.float32) + + keypoints = np.array(obj[f'{type}hand_kpts']) + keypoints = keypoints.reshape(-1, 3) + joints[:, :2] = keypoints[:, :2] + joints_vis[:, :2] = np.minimum(1, keypoints[:, 2:3]) + + gt_db.append({ + 'bbox': obj[f'{type}hand_box'], + 'gt_joints': joints, + 'joints_vis': joints_vis, + }) + self.db = gt_db + + def reset(self): + self.results = { + 'preds': np.zeros( + (self.num_samples, self.num_joints, 3), dtype=np.float32), + } + self.eval_results = {} + self.idx = 0 + + def update(self, inputs, outputs): + kpts, _ = outputs['keypoint'][0] + num_images = inputs['image'].shape[0] + self.results['preds'][self.idx:self.idx + num_images, :, 0: + 3] = kpts[:, :, 0:3] + self.idx += num_images + + def accumulate(self): + self.get_final_results(self.results['preds']) + if self.save_prediction_only: + logger.info(f'The keypoint result is saved to {self.res_file} ' + 'and do not evaluate the mAP.') + return + + self.eval_results = self.evaluate(self.res_file, ('PCK', 'AUC', 'EPE')) + + def get_final_results(self, preds): + kpts = [] + for idx, kpt in enumerate(preds): + kpts.append({'keypoints': kpt.tolist()}) + + self._write_keypoint_results(kpts) + + def _write_keypoint_results(self, keypoints): + if not os.path.exists(self.output_eval): + os.makedirs(self.output_eval) + with open(self.res_file, 'w') as f: + json.dump(keypoints, f, sort_keys=True, indent=4) + logger.info(f'The keypoint result is saved to {self.res_file}.') + try: + json.load(open(self.res_file)) + except Exception: + content = [] + with open(self.res_file, 'r') as f: + for line in f: + content.append(line) + content[-1] = ']' + with open(self.res_file, 'w') as f: + for c in content: + f.write(c) + + def log(self): + if self.save_prediction_only: + return + for item, value in self.eval_results.items(): + print("{} : {}".format(item, value)) + + def get_results(self): + return self.eval_results + + def evaluate(self, res_file, metrics, pck_thr=0.2, auc_nor=30): + """Keypoint evaluation. + + Args: + res_file (str): Json file stored prediction results. + metrics (str | list[str]): Metric to be performed. + Options: 'PCK', 'AUC', 'EPE'. + pck_thr (float): PCK threshold, default as 0.2. + auc_nor (float): AUC normalization factor, default as 30 pixel. + + Returns: + List: Evaluation results for evaluation metric. 
+ """ + info_str = [] + + with open(res_file, 'r') as fin: + preds = json.load(fin) + assert len(preds) == len(self.db) + + outputs = [] + gts = [] + masks = [] + threshold_bbox = [] + + for pred, item in zip(preds, self.db): + outputs.append(np.array(pred['keypoints'])[:, :-1]) + gts.append(np.array(item['gt_joints'])[:, :-1]) + masks.append((np.array(item['joints_vis'])[:, 0]) > 0) + if 'PCK' in metrics: + bbox = np.array(item['bbox']) + bbox_thr = np.max(bbox[2:]) + threshold_bbox.append(np.array([bbox_thr, bbox_thr])) + + outputs = np.array(outputs) + gts = np.array(gts) + masks = np.array(masks) + threshold_bbox = np.array(threshold_bbox) + + if 'PCK' in metrics: + _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr, + threshold_bbox) + info_str.append(('PCK', pck)) + + if 'AUC' in metrics: + info_str.append(('AUC', keypoint_auc(outputs, gts, masks, auc_nor))) + + if 'EPE' in metrics: + info_str.append(('EPE', keypoint_epe(outputs, gts, masks))) + + name_value = OrderedDict(info_str) + + return name_value + + class KeyPointTopDownMPIIEval(object): def __init__(self, anno_file, diff --git a/ppdet/modeling/keypoint_utils.py b/ppdet/modeling/keypoint_utils.py index 377f1d75c..382e37317 100644 --- a/ppdet/modeling/keypoint_utils.py +++ b/ppdet/modeling/keypoint_utils.py @@ -401,3 +401,151 @@ def flip_back(output_flipped, flip_pairs, target_type='GaussianHeatmap'): # Flip horizontally output_flipped_back = output_flipped_back[..., ::-1] return output_flipped_back + + +def _calc_distances(preds, targets, mask, normalize): + """Calculate the normalized distances between preds and target. + + Note: + batch_size: N + num_keypoints: K + dimension of keypoints: D (normally, D=2 or D=3) + + Args: + preds (np.ndarray[N, K, D]): Predicted keypoint location. + targets (np.ndarray[N, K, D]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize (np.ndarray[N, D]): Typical value is heatmap_size + + Returns: + np.ndarray[K, N]: The normalized distances. \ + If target keypoints are missing, the distance is -1. + """ + N, K, _ = preds.shape + # set mask=0 when normalize==0 + _mask = mask.copy() + _mask[np.where((normalize == 0).sum(1))[0], :] = False + distances = np.full((N, K), -1, dtype=np.float32) + # handle invalid values + normalize[np.where(normalize <= 0)] = 1e6 + distances[_mask] = np.linalg.norm( + ((preds - targets) / normalize[:, None, :])[_mask], axis=-1) + return distances.T + + +def _distance_acc(distances, thr=0.5): + """Return the percentage below the distance threshold, while ignoring + distances values with -1. + + Note: + batch_size: N + Args: + distances (np.ndarray[N, ]): The normalized distances. + thr (float): Threshold of the distances. + + Returns: + float: Percentage of distances below the threshold. \ + If all target keypoints are missing, return -1. + """ + distance_valid = distances != -1 + num_distance_valid = distance_valid.sum() + if num_distance_valid > 0: + return (distances[distance_valid] < thr).sum() / num_distance_valid + return -1 + + +def keypoint_pck_accuracy(pred, gt, mask, thr, normalize): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + PCK metric measures accuracy of the localization of the body joints. 
+ The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - acc (np.ndarray[K]): Accuracy of each keypoint. + - avg_acc (float): Averaged accuracy across all keypoints. + - cnt (int): Number of valid keypoints. + """ + distances = _calc_distances(pred, gt, mask, normalize) + + acc = np.array([_distance_acc(d, thr) for d in distances]) + valid_acc = acc[acc >= 0] + cnt = len(valid_acc) + avg_acc = valid_acc.mean() if cnt > 0 else 0 + return acc, avg_acc, cnt + + +def keypoint_auc(pred, gt, mask, normalize, num_step=20): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize (float): Normalization factor. + + Returns: + float: Area under curve. + """ + nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1)) + x = [1.0 * i / num_step for i in range(num_step)] + y = [] + for thr in x: + _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor) + y.append(avg_acc) + + auc = 0 + for i in range(num_step): + auc += 1.0 / num_step * y[i] + return auc + + +def keypoint_epe(pred, gt, mask): + """Calculate the end-point error. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + + Returns: + float: Average end-point error. + """ + + normalize = np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32) + distances = _calc_distances(pred, gt, mask, normalize) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) -- GitLab
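
For readers unfamiliar with the top-down keypoint convention used above, the following standalone numpy sketch mirrors the center/scale computation of KeypointTopDownCocoWholeBodyHandDataset._box2cs: the hand box is fitted to the training aspect ratio, the scale is expressed in units of pixel_std (200 by default), and then padded by 1.25. The helper name and the box values are illustrative only and are not part of the patch.

import numpy as np

def hand_box_to_center_scale(box, trainsize=(256, 256), pixel_std=200):
    # Mirrors _box2cs above: fit the box to the training aspect ratio,
    # express the scale in pixel_std units, then pad by a factor of 1.25.
    x, y, w, h = box[:4]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
    aspect_ratio = trainsize[0] / trainsize[1]
    if w > aspect_ratio * h:
        h = w / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    scale = np.array([w / pixel_std, h / pixel_std], dtype=np.float32)
    if center[0] != -1:
        scale = scale * 1.25
    return center, scale

center, scale = hand_box_to_center_scale([120.0, 80.0, 60.0, 40.0])
print(center, scale)  # center = [150. 100.], scale ~ [0.375 0.375] after padding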
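
A minimal sketch of how the new PCK/AUC/EPE utilities in ppdet/modeling/keypoint_utils.py can be exercised on synthetic data, assuming the patch has been applied and ppdet is importable. The instance count, bounding-box size and noise level below are illustrative values only; the PCK normalization mirrors the evaluate() method above, which uses the larger bbox side for both axes.

import numpy as np
from ppdet.modeling.keypoint_utils import (
    keypoint_pck_accuracy, keypoint_auc, keypoint_epe)

N, K = 4, 21                    # 4 hand instances, 21 COCO-WholeBody hand keypoints
rng = np.random.default_rng(0)
gt = rng.uniform(0, 256, size=(N, K, 2)).astype(np.float32)  # ground-truth coordinates
pred = gt + rng.normal(0.0, 4.0, size=(N, K, 2))             # predictions with pixel noise
mask = np.ones((N, K), dtype=bool)                           # all joints treated as visible

# PCK: normalize each instance by its bbox size; the metric class above uses
# max(bbox[2:]) for both axes, emulated here with a fixed 256-pixel box.
bbox_size = np.full((N, 2), 256.0, dtype=np.float32)
acc, avg_acc, cnt = keypoint_pck_accuracy(pred, gt, mask, thr=0.2, normalize=bbox_size)
print('PCK@0.2 per keypoint:', acc.round(3), 'mean:', avg_acc, 'valid keypoints:', cnt)

# AUC with the 30-pixel normalization the metric class uses by default (auc_nor=30).
print('AUC:', keypoint_auc(pred, gt, mask, normalize=30))

# EPE: mean end-point error in pixels over all visible joints.
print('EPE:', keypoint_epe(pred, gt, mask))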