From f0ae1b64aeef499081e93bd528885000fbdce4a4 Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Tue, 27 Sep 2022 12:55:18 +0800 Subject: [PATCH] pose3d part3: train and metric (#6613) * pose3d train * delete extra comments * only pose3d parallel eval; delete pose3d config * fix mistake edit --- ppdet/engine/callbacks.py | 86 ++++++++------ ppdet/engine/export_utils.py | 1 + ppdet/engine/trainer.py | 41 +++++-- ppdet/metrics/__init__.py | 1 + ppdet/metrics/coco_utils.py | 8 +- ppdet/metrics/json_results.py | 16 +++ ppdet/metrics/pose3d_metrics.py | 202 ++++++++++++++++++++++++++++++++ ppdet/utils/visualizer.py | 134 +++++++++++++++++++++ 8 files changed, 443 insertions(+), 46 deletions(-) create mode 100644 ppdet/metrics/pose3d_metrics.py diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index 09683d18b..14ece8ceb 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -160,7 +160,7 @@ class Checkpointer(Callback): def __init__(self, model): super(Checkpointer, self).__init__(model) cfg = self.model.cfg - self.best_ap = 0. + self.best_ap = -1000. self.save_dir = os.path.join(self.model.cfg.save_dir, self.model.cfg.filename) if hasattr(self.model.model, 'student_model'): @@ -187,7 +187,11 @@ class Checkpointer(Callback): if 'save_best_model' in status and status['save_best_model']: for metric in self.model._metrics: map_res = metric.get_results() - if 'bbox' in map_res: + eval_func = "ap" + if 'pose3d' in map_res: + key = 'pose3d' + eval_func = "mpjpe" + elif 'bbox' in map_res: key = 'bbox' elif 'keypoint' in map_res: key = 'keypoint' @@ -202,8 +206,8 @@ class Checkpointer(Callback): self.best_ap = map_res[key][0] save_name = 'best_model' weight = self.weight.state_dict() - logger.info("Best test {} ap is {:0.3f}.".format( - key, self.best_ap)) + logger.info("Best test {} {} is {:0.3f}.".format( + key, eval_func, abs(self.best_ap))) if weight: if self.model.use_ema: # save model and ema_model @@ -288,6 +292,7 @@ class VisualDLWriter(Callback): self.vdl_mAP_step) self.vdl_mAP_step += 1 + class WandbCallback(Callback): def __init__(self, model): super(WandbCallback, self).__init__(model) @@ -307,10 +312,8 @@ class WandbCallback(Callback): self.wandb_params = {} for k, v in model.cfg.items(): if k.startswith("wandb_"): - self.wandb_params.update({ - k.lstrip("wandb_"): v - }) - + self.wandb_params.update({k.lstrip("wandb_"): v}) + self._run = None if dist.get_world_size() < 2 or dist.get_rank() == 0: _ = self.run @@ -318,28 +321,29 @@ class WandbCallback(Callback): self.run.define_metric("epoch") self.run.define_metric("eval/*", step_metric="epoch") - self.best_ap = 0 - + self.best_ap = -1000. + @property def run(self): if self._run is None: if self.wandb.run is not None: - logger.info("There is an ongoing wandb run which will be used" - "for logging. Please use `wandb.finish()` to end that" - "if the behaviour is not intended") + logger.info( + "There is an ongoing wandb run which will be used" + "for logging. Please use `wandb.finish()` to end that" + "if the behaviour is not intended") self._run = self.wandb.run else: self._run = self.wandb.init(**self.wandb_params) return self._run - + def save_model(self, - optimizer, - save_dir, - save_name, - last_epoch, - ema_model=None, - ap=None, - tags=None): + optimizer, + save_dir, + save_name, + last_epoch, + ema_model=None, + ap=None, + tags=None): if dist.get_world_size() < 2 or dist.get_rank() == 0: model_path = os.path.join(save_dir, save_name) metadata = {} @@ -347,8 +351,14 @@ class WandbCallback(Callback): if ap: metadata["ap"] = ap if ema_model is None: - ema_artifact = self.wandb.Artifact(name="ema_model-{}".format(self.run.id), type="model", metadata=metadata) - model_artifact = self.wandb.Artifact(name="model-{}".format(self.run.id), type="model", metadata=metadata) + ema_artifact = self.wandb.Artifact( + name="ema_model-{}".format(self.run.id), + type="model", + metadata=metadata) + model_artifact = self.wandb.Artifact( + name="model-{}".format(self.run.id), + type="model", + metadata=metadata) ema_artifact.add_file(model_path + ".pdema", name="model_ema") model_artifact.add_file(model_path + ".pdparams", name="model") @@ -356,10 +366,13 @@ class WandbCallback(Callback): self.run.log_artifact(ema_artifact, aliases=tags) self.run.log_artfact(model_artifact, aliases=tags) else: - model_artifact = self.wandb.Artifact(name="model-{}".format(self.run.id), type="model", metadata=metadata) + model_artifact = self.wandb.Artifact( + name="model-{}".format(self.run.id), + type="model", + metadata=metadata) model_artifact.add_file(model_path + ".pdparams", name="model") self.run.log_artifact(model_artifact, aliases=tags) - + def on_step_end(self, status): mode = status['mode'] @@ -368,11 +381,9 @@ class WandbCallback(Callback): training_status = status['training_staus'].get() for k, v in training_status.items(): training_status[k] = float(v) - metrics = { - "train/" + k: v for k,v in training_status.items() - } + metrics = {"train/" + k: v for k, v in training_status.items()} self.run.log(metrics) - + def on_epoch_end(self, status): mode = status['mode'] epoch_id = status['epoch_id'] @@ -383,7 +394,8 @@ class WandbCallback(Callback): if ( epoch_id + 1 ) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1: - save_name = str(epoch_id) if epoch_id != end_epoch - 1 else "model_final" + save_name = str( + epoch_id) if epoch_id != end_epoch - 1 else "model_final" tags = ["latest", "epoch_{}".format(epoch_id)] self.save_model( self.model.optimizer, @@ -391,8 +403,7 @@ class WandbCallback(Callback): save_name, epoch_id + 1, self.model.use_ema, - tags=tags - ) + tags=tags) if mode == 'eval': merged_dict = {} for metric in self.model._metrics: @@ -404,7 +415,9 @@ class WandbCallback(Callback): if 'save_best_model' in status and status['save_best_model']: for metric in self.model._metrics: map_res = metric.get_results() - if 'bbox' in map_res: + if 'pose3d' in map_res: + key = 'pose3d' + elif 'bbox' in map_res: key = 'bbox' elif 'keypoint' in map_res: key = 'keypoint' @@ -426,10 +439,9 @@ class WandbCallback(Callback): save_name, last_epoch=epoch_id + 1, ema_model=self.model.use_ema, - ap=self.best_ap, - tags=tags - ) - + ap=abs(self.best_ap), + tags=tags) + def on_train_end(self, status): self.run.finish() diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index f9f90a9b7..0f71ee6f5 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -49,6 +49,7 @@ TRT_MIN_SUBGRAPH = { 'CenterNet': 5, 'TOOD': 5, 'YOLOX': 8, + 'METRO_Body': 3, } KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet'] diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 53d7296a0..2ce783c5a 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -38,7 +38,7 @@ from ppdet.optimizer import ModelEMA from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result -from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval +from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval, Pose3DEval from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric from ppdet.data.source.sniper_coco import SniperCOCODataSet from ppdet.data.source.category import get_categories @@ -136,6 +136,9 @@ class Trainer(object): if self.mode == 'eval': if cfg.architecture == 'FairMOT': self.loader = create('EvalMOTReader')(self.dataset, 0) + elif cfg.architecture == "METRO_Body": + reader_name = '{}Reader'.format(self.mode.capitalize()) + self.loader = create(reader_name)(self.dataset, cfg.worker_num) else: self._eval_batch_sampler = paddle.io.BatchSampler( self.dataset, batch_size=self.cfg.EvalReader['batch_size']) @@ -342,6 +345,13 @@ class Trainer(object): self.cfg.save_dir, save_prediction_only=save_prediction_only) ] + elif self.cfg.metric == 'Pose3DEval': + save_prediction_only = self.cfg.get('save_prediction_only', False) + self._metrics = [ + Pose3DEval( + self.cfg.save_dir, + save_prediction_only=save_prediction_only) + ] elif self.cfg.metric == 'MOTDet': self._metrics = [JDEDetMetric(), ] else: @@ -450,6 +460,7 @@ class Trainer(object): self.loader.dataset.set_epoch(epoch_id) model.train() iter_tic = time.time() + print("loader len:", len(self.loader)) for step_id, data in enumerate(self.loader): self.status['data_time'].update(time.time() - iter_tic) self.status['step_id'] = step_id @@ -537,7 +548,7 @@ class Trainer(object): self._compose_callback.on_epoch_end(self.status) - if validate and is_snapshot: + if validate: if not hasattr(self, '_eval_loader'): # build evaluation dataset and loader self._eval_dataset = self.cfg.EvalDataset @@ -548,10 +559,14 @@ class Trainer(object): # If metric is VOC, need to be set collate_batch=False. if self.cfg.metric == 'VOC': self.cfg['EvalReader']['collate_batch'] = False - self._eval_loader = create('EvalReader')( - self._eval_dataset, - self.cfg.worker_num, - batch_sampler=self._eval_batch_sampler) + if self.cfg.metric == "Pose3DEval": + self._eval_loader = create('EvalReader')( + self._eval_dataset, self.cfg.worker_num) + else: + self._eval_loader = create('EvalReader')( + self._eval_dataset, + self.cfg.worker_num, + batch_sampler=self._eval_batch_sampler) # if validation in training is enabled, metrics should be re-init # Init_mark makes sure this code will only execute once if validate and Init_mark == False: @@ -575,6 +590,7 @@ class Trainer(object): tic = time.time() self._compose_callback.on_epoch_begin(self.status) self.status['mode'] = 'eval' + self.model.eval() if self.cfg.get('print_flops', False): flops_loader = create('{}Reader'.format(self.mode.capitalize()))( @@ -617,6 +633,15 @@ class Trainer(object): self._reset_metrics() def evaluate(self): + # get distributed model + if self.cfg.get('fleet', False): + self.model = fleet.distributed_model(self.model) + self.optimizer = fleet.distributed_optimizer(self.optimizer) + elif self._nranks > 1: + find_unused_parameters = self.cfg[ + 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False + self.model = paddle.DataParallel( + self.model, find_unused_parameters=find_unused_parameters) with paddle.no_grad(): self._eval_with_loader(self.loader) @@ -921,9 +946,11 @@ class Trainer(object): if 'segm' in batch_res else None keypoint_res = batch_res['keypoint'][start:end] \ if 'keypoint' in batch_res else None + pose3d_res = batch_res['pose3d'][start:end] \ + if 'pose3d' in batch_res else None image = visualize_results( image, bbox_res, mask_res, segm_res, keypoint_res, - int(im_id), catid2name, draw_threshold) + pose3d_res, int(im_id), catid2name, draw_threshold) self.status['result_image'] = np.array(image.copy()) if self._compose_callback: self._compose_callback.on_step_end(self.status) diff --git a/ppdet/metrics/__init__.py b/ppdet/metrics/__init__.py index d69e8af0f..3e1b83cca 100644 --- a/ppdet/metrics/__init__.py +++ b/ppdet/metrics/__init__.py @@ -17,6 +17,7 @@ from . import keypoint_metrics from .metrics import * from .keypoint_metrics import * +from .pose3d_metrics import * __all__ = metrics.__all__ + keypoint_metrics.__all__ diff --git a/ppdet/metrics/coco_utils.py b/ppdet/metrics/coco_utils.py index 47b92bc62..b7a4d7e32 100644 --- a/ppdet/metrics/coco_utils.py +++ b/ppdet/metrics/coco_utils.py @@ -21,7 +21,7 @@ import sys import numpy as np import itertools -from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res, get_keypoint_res +from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res, get_keypoint_res, get_pose3d_res from ppdet.metrics.map_utils import draw_pr_curve from ppdet.utils.logger import setup_logger @@ -64,6 +64,10 @@ def get_infer_results(outs, catid, bias=0): infer_res['keypoint'] = get_keypoint_res(outs, im_id) outs['bbox_num'] = [len(infer_res['keypoint'])] + if 'pose3d' in outs: + infer_res['pose3d'] = get_pose3d_res(outs, im_id) + outs['bbox_num'] = [len(infer_res['pose3d'])] + return infer_res @@ -150,7 +154,7 @@ def cocoapi_eval(jsonfile, results_flatten = list(itertools.chain(*results_per_category)) headers = ['category', 'AP'] * (num_columns // 2) results_2d = itertools.zip_longest( - *[results_flatten[i::num_columns] for i in range(num_columns)]) + * [results_flatten[i::num_columns] for i in range(num_columns)]) table_data = [headers] table_data += [result for result in results_2d] table = AsciiTable(table_data) diff --git a/ppdet/metrics/json_results.py b/ppdet/metrics/json_results.py index 93354ec1f..d2575af43 100755 --- a/ppdet/metrics/json_results.py +++ b/ppdet/metrics/json_results.py @@ -157,3 +157,19 @@ def get_keypoint_res(results, im_id): ann['bbox'] = [x0, y0, x1 - x0, y1 - y0] anns.append(ann) return anns + + +def get_pose3d_res(results, im_id): + anns = [] + preds = results['pose3d'] + for idx in range(im_id.shape[0]): + image_id = im_id[idx].item() + pose3d = preds[idx] + ann = { + 'image_id': image_id, + 'category_id': 1, # XXX hard code + 'pose3d': pose3d.tolist(), + 'score': float(1.) + } + anns.append(ann) + return anns diff --git a/ppdet/metrics/pose3d_metrics.py b/ppdet/metrics/pose3d_metrics.py new file mode 100644 index 000000000..45b9239a5 --- /dev/null +++ b/ppdet/metrics/pose3d_metrics.py @@ -0,0 +1,202 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import os +import json +from collections import defaultdict, OrderedDict +import numpy as np +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +__all__ = ['Pose3DEval'] + + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def mean_per_joint_position_error(pred, gt, has_3d_joints): + """ + Compute mPJPE + """ + gt = gt[has_3d_joints == 1] + gt = gt[:, :, :3] + pred = pred[has_3d_joints == 1] + + with paddle.no_grad(): + gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2 + gt = gt - gt_pelvis[:, None, :] + pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2 + pred = pred - pred_pelvis[:, None, :] + error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean(axis=-1).numpy() + return error + + +def compute_similarity_transform(S1, S2): + """Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.T + S2 = S2.T + transposed = True + assert (S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=1, keepdims=True) + mu2 = S2.mean(axis=1, keepdims=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = np.sum(X1**2) + + # 3. The outer product of X1 and X2. + K = X1.dot(X2.T) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, Vh = np.linalg.svd(K) + V = Vh.T + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = np.eye(U.shape[0]) + Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) + # Construct R. + R = V.dot(Z.dot(U.T)) + + # 5. Recover scale. + scale = np.trace(R.dot(K)) / var1 + + # 6. Recover translation. + t = mu2 - scale * (R.dot(mu1)) + + # 7. Error: + S1_hat = scale * R.dot(S1) + t + + if transposed: + S1_hat = S1_hat.T + + return S1_hat + + +def compute_similarity_transform_batch(S1, S2): + """Batched version of compute_similarity_transform.""" + S1_hat = np.zeros_like(S1) + for i in range(S1.shape[0]): + S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) + return S1_hat + + +def reconstruction_error(S1, S2, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re + + +def all_gather(data): + if paddle.distributed.get_world_size() == 1: + return data + vlist = [] + paddle.distributed.all_gather(vlist, data) + data = paddle.concat(vlist, 0) + return data + + +class Pose3DEval(object): + """refer to + https://github.com/leoxiaobin/deep-high-resolution-net.pytorch + Copyright (c) Microsoft, under the MIT License. + """ + + def __init__(self, output_eval, save_prediction_only=False): + super(Pose3DEval, self).__init__() + self.output_eval = output_eval + self.res_file = os.path.join(output_eval, "pose3d_results.json") + self.save_prediction_only = save_prediction_only + self.reset() + + def reset(self): + self.PAmPJPE = AverageMeter() + self.mPJPE = AverageMeter() + self.eval_results = {} + + def get_human36m_joints(self, input): + J24_TO_J14 = paddle.to_tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]) + J24_TO_J17 = paddle.to_tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19]) + return paddle.index_select(input, J24_TO_J14, axis=1) + + def update(self, inputs, outputs): + gt_3d_joints = all_gather(inputs['joints_3d']) + has_3d_joints = all_gather(inputs['has_3d_joints']) + pred_3d_joints = all_gather(outputs['pose3d']) + if gt_3d_joints.shape[1] == 24: + gt_3d_joints = self.get_human36m_joints(gt_3d_joints) + if pred_3d_joints.shape[1] == 24: + pred_3d_joints = self.get_human36m_joints(pred_3d_joints) + mPJPE_val = mean_per_joint_position_error(pred_3d_joints, gt_3d_joints, + has_3d_joints).mean() + PAmPJPE_val = reconstruction_error( + pred_3d_joints.numpy(), + gt_3d_joints[:, :, :3].numpy(), + reduction=None).mean() + count = int(np.sum(has_3d_joints.numpy())) + self.PAmPJPE.update(PAmPJPE_val * 1000., count) + self.mPJPE.update(mPJPE_val * 1000., count) + + def accumulate(self): + if self.save_prediction_only: + logger.info(f'The pose3d result is saved to {self.res_file} ' + 'and do not evaluate the model.') + return + self.eval_results['pose3d'] = [-self.mPJPE.avg, -self.PAmPJPE.avg] + + def log(self): + if self.save_prediction_only: + return + stats_names = ['mPJPE', 'PAmPJPE'] + num_values = len(stats_names) + print(' '.join(['| {}'.format(name) for name in stats_names]) + ' |') + print('|---' * (num_values + 1) + '|') + + print(' '.join([ + '| {:.3f}'.format(abs(value)) + for value in self.eval_results['pose3d'] + ]) + ' |') + + def get_results(self): + return self.eval_results diff --git a/ppdet/utils/visualizer.py b/ppdet/utils/visualizer.py index fdfd966e2..135180854 100644 --- a/ppdet/utils/visualizer.py +++ b/ppdet/utils/visualizer.py @@ -34,6 +34,7 @@ def visualize_results(image, mask_res, segm_res, keypoint_res, + pose3d_res, im_id, catid2name, threshold=0.5): @@ -48,6 +49,8 @@ def visualize_results(image, image = draw_segm(image, im_id, catid2name, segm_res, threshold) if keypoint_res is not None: image = draw_pose(image, keypoint_res, threshold) + if pose3d_res is not None: + image = draw_pose3d(image, pose3d_res, threshold) return image @@ -319,3 +322,134 @@ def draw_pose(image, image = Image.fromarray(canvas.astype('uint8')) plt.close() return image + + +def draw_pose3d(image, + results, + visual_thread=0.6, + save_name='pose3d.jpg', + save_dir='output', + returnimg=False, + ids=None): + try: + import matplotlib.pyplot as plt + import matplotlib + plt.switch_backend('agg') + except Exception as e: + logger.error('Matplotlib not found, please install matplotlib.' + 'for example: `pip install matplotlib`.') + raise e + pose3d = np.array(results[0]['pose3d']) * 1000 + + if pose3d.shape[0] == 24: + joints_connectivity_dict = [ + [0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 14, 1], + [3, 14, 1], [14, 15, 1], [15, 16, 1], [16, 12, 1], [6, 7, 0], + [7, 8, 0], [11, 10, 1], [10, 9, 1], [8, 12, 0], [9, 12, 1], + [12, 19, 1], [19, 18, 1], [19, 20, 0], [19, 21, 1], [22, 20, 0], + [23, 21, 1] + ] + elif pose3d.shape[0] == 14: + joints_connectivity_dict = [ + [0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 12, 0], + [3, 12, 1], [6, 7, 0], [7, 8, 0], [11, 10, 1], [10, 9, 1], + [8, 12, 0], [9, 12, 1], [12, 13, 1] + ] + else: + print( + "not defined joints number :{}, cannot visualize because unknown of joint connectivity". + format(pose.shape[0])) + return + + def draw3Dpose(pose3d, + ax, + lcolor="#3498db", + rcolor="#e74c3c", + add_labels=False): + # pose3d = orthographic_projection(pose3d, cam) + for i in joints_connectivity_dict: + x, y, z = [ + np.array([pose3d[i[0], j], pose3d[i[1], j]]) for j in range(3) + ] + ax.plot(-x, -z, -y, lw=2, c=lcolor if i[2] else rcolor) + + RADIUS = 1000 + center_xy = 2 if pose3d.shape[0] == 14 else 14 + x, y, z = pose3d[center_xy, 0], pose3d[center_xy, 1], pose3d[center_xy, + 2] + ax.set_xlim3d([-RADIUS + x, RADIUS + x]) + ax.set_ylim3d([-RADIUS + y, RADIUS + y]) + ax.set_zlim3d([-RADIUS + z, RADIUS + z]) + + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + + def draw2Dpose(pose2d, + ax, + lcolor="#3498db", + rcolor="#e74c3c", + add_labels=False): + for i in joints_connectivity_dict: + if pose2d[i[0], 2] and pose2d[i[1], 2]: + x, y = [ + np.array([pose2d[i[0], j], pose2d[i[1], j]]) + for j in range(2) + ] + ax.plot(x, y, 0, lw=2, c=lcolor if i[2] else rcolor) + + def draw_img_pose(pose3d, + pose2d=None, + frame=None, + figsize=(12, 12), + savepath=None): + fig = plt.figure(figsize=figsize, dpi=80) + # fig.clear() + fig.tight_layout() + + ax = fig.add_subplot(221) + if frame is not None: + ax.imshow(frame, interpolation='nearest') + if pose2d is not None: + draw2Dpose(pose2d, ax) + + ax = fig.add_subplot(222, projection='3d') + ax.view_init(45, 45) + draw3Dpose(pose3d, ax) + ax = fig.add_subplot(223, projection='3d') + ax.view_init(0, 0) + draw3Dpose(pose3d, ax) + ax = fig.add_subplot(224, projection='3d') + ax.view_init(0, 90) + draw3Dpose(pose3d, ax) + + if savepath is not None: + plt.savefig(savepath) + plt.close() + else: + return fig + + def fig2data(fig): + """ + fig = plt.figure() + image = fig2data(fig) + @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it + @param fig a matplotlib figure + @return a numpy 3D array of RGBA values + """ + # draw the renderer + fig.canvas.draw() + + # Get the RGBA buffer from the figure + w, h = fig.canvas.get_width_height() + buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) + buf.shape = (w, h, 4) + + # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode + buf = np.roll(buf, 3, axis=2) + image = Image.frombytes("RGBA", (w, h), buf.tostring()) + return image.convert("RGB") + + fig = draw_img_pose(pose3d, frame=image) + data = fig2data(fig) + return data -- GitLab