From fa1ba1aa9235ed14a65aba65ccc9c2615c9d4ec1 Mon Sep 17 00:00:00 2001 From: LokeZhou Date: Thu, 9 Mar 2023 13:54:49 +0800 Subject: [PATCH] Add vitpose (#7894) * add vitpose * fix keypoints/README_en.md document_fix=test * add vitpose_base_coco_256x192 * vitpose.py add some annotation for discriminate visio_transformer.py * output_heatmap use gpu --- configs/keypoint/README.md | 13 +- configs/keypoint/README_en.md | 13 +- .../vit_pose/vitpose_base_coco_256x196.yml | 171 ++++++++++ .../vitpose_base_simple_coco_256x192.yml | 164 +++++++++ ppdet/data/source/keypoint_coco.py | 7 +- ppdet/modeling/architectures/__init__.py | 2 + .../architectures/keypoint_vitpose.py | 317 +++++++++++++++++ ppdet/modeling/backbones/__init__.py | 3 +- ppdet/modeling/backbones/vitpose.py | 320 ++++++++++++++++++ ppdet/modeling/heads/__init__.py | 2 + ppdet/modeling/heads/vitpose_head.py | 278 +++++++++++++++ ppdet/modeling/keypoint_utils.py | 61 ++++ ppdet/optimizer/adamw.py | 2 +- 13 files changed, 1348 insertions(+), 5 deletions(-) create mode 100644 configs/keypoint/vit_pose/vitpose_base_coco_256x196.yml create mode 100644 configs/keypoint/vit_pose/vitpose_base_simple_coco_256x192.yml create mode 100644 ppdet/modeling/architectures/keypoint_vitpose.py create mode 100644 ppdet/modeling/backbones/vitpose.py create mode 100644 ppdet/modeling/heads/vitpose_head.py diff --git a/configs/keypoint/README.md b/configs/keypoint/README.md index c93932d73..8b08f0920 100644 --- a/configs/keypoint/README.md +++ b/configs/keypoint/README.md @@ -72,8 +72,11 @@ COCO数据集 | LiteHRNet-18 |Top-Down| 384x288 | 69.7 | [lite_hrnet_18_384x288_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_18_384x288_coco.pdparams) | [config](./lite_hrnet/lite_hrnet_18_384x288_coco.yml) | | LiteHRNet-30 | Top-Down|256x192 | 69.4 | [lite_hrnet_30_256x192_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_30_256x192_coco.pdparams) | [config](./lite_hrnet/lite_hrnet_30_256x192_coco.yml) | | LiteHRNet-30 |Top-Down| 384x288 | 72.5 | [lite_hrnet_30_384x288_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_30_384x288_coco.pdparams) | [config](./lite_hrnet/lite_hrnet_30_384x288_coco.yml) | +|Vitpose_base_simple |Top-Down| 256x192 | 77.7 | [vitpose_base_simple_256x192_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/vitpose_base_simple_256x192_coco.pdparams) | [config](./vit_pose/vitpose_base_simple_coco_256x192.yml) | +|Vitpose_base |Top-Down| 256x192 | 78.2 | [vitpose_base_coco_256x192.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/vitpose_base_coco_256x192.pdparams) | [config](./vit_pose/vitpose_base_coco_256x192.yml) | -备注: Top-Down模型测试AP结果基于GroundTruth标注框 +备注: 1.Top-Down模型测试AP结果基于GroundTruth标注框 + 2.vitpose训练用[MAE](https://bj.bcebos.com/v1/paddledet/models/keypoint/mae_pretrain_vit_base.pdparams)做为预训练模型 MPII数据集 | 模型 | 方案| 输入尺寸 | PCKh(Mean) | PCKh(Mean@0.1) | 模型下载 | 配置文件 | @@ -284,4 +287,12 @@ python deploy/python/det_keypoint_unite_infer.py \ booktitle={CVPR}, year={2021} } + +@inproceedings{ + xu2022vitpose, + title={ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation}, + author={Yufei Xu and Jing Zhang and Qiming Zhang and Dacheng Tao}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022}, +} ``` diff --git a/configs/keypoint/README_en.md b/configs/keypoint/README_en.md index 15f659645..252f31745 100644 --- a/configs/keypoint/README_en.md +++ b/configs/keypoint/README_en.md @@ -75,8 +75,11 @@ COCO Dataset | LiteHRNet-18 | 384x288 | 69.7 | [lite_hrnet_18_384x288_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_18_384x288_coco.pdparams) | [config](./lite_hrnet/lite_hrnet_18_384x288_coco.yml) | | LiteHRNet-30 | 256x192 | 69.4 | [lite_hrnet_30_256x192_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_30_256x192_coco.pdparams) | [config](./lite_hrnet/lite_hrnet_30_256x192_coco.yml) | | LiteHRNet-30 | 384x288 | 72.5 | [lite_hrnet_30_384x288_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/lite_hrnet_30_384x288_coco.pdparams) | [config](./lite_hrnet/lite_hrnet_30_384x288_coco.yml) | +| Vitpose_base_simple | 256x192 | 77.7 | [vitpose_base_simple_256x192_coco.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/vitpose_base_simple_256x192_coco.pdparams) | [config](./vit_pose/vitpose_base_simple_coco_256x192.yml) | +| Vitpose_base | 256x192 | 78.2 | [vitpose_base_coco_256x192.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/vitpose_base_coco_256x192.pdparams) | [config](./vit_pose/vitpose_base_coco_256x192.yml) | -Note:The AP results of Top-Down models are based on bounding boxes in GroundTruth. +Note:1.The AP results of Top-Down models are based on bounding boxes in GroundTruth. + 2.Vitpose training uses [MAE](https://bj.bcebos.com/v1/paddledet/models/keypoint/mae_pretrain_vit_base.pdparams) as the pre-training model MPII Dataset | Model | Input Size | PCKh(Mean) | PCKh(Mean@0.1) | Model Download | Config File | @@ -266,4 +269,12 @@ We provide benchmarks in different runtime environments for your reference when booktitle={CVPR}, year={2021} } + +@inproceedings{ + xu2022vitpose, + title={ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation}, + author={Yufei Xu and Jing Zhang and Qiming Zhang and Dacheng Tao}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022}, +} ``` diff --git a/configs/keypoint/vit_pose/vitpose_base_coco_256x196.yml b/configs/keypoint/vit_pose/vitpose_base_coco_256x196.yml new file mode 100644 index 000000000..3e4934e3a --- /dev/null +++ b/configs/keypoint/vit_pose/vitpose_base_coco_256x196.yml @@ -0,0 +1,171 @@ +use_gpu: true +log_iter: 50 +save_dir: output +snapshot_epoch: 10 +weights: output/vitpose_base_simple_coco_256x192/model_final +epoch: 210 +num_joints: &num_joints 17 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownCOCOEval +num_classes: 1 +train_height: &train_height 256 +train_width: &train_width 192 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [48, 64] +flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] + + +#####model +architecture: VitPose_TopDown +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/mae_pretrain_vit_base.pdparams +VitPose_TopDown: + backbone: ViT + head: TopdownHeatmapSimpleHead + post_process: VitPosePostProcess + loss: KeyPointMSELoss + flip_test: True + +ViT: + img_size: [256, 192] + patch_size: 16 + embed_dim: 768 + depth: 12 + num_heads: 12 + ratio: 1 + mlp_ratio: 4 + qkv_bias: True + drop_path_rate: 0.3 + epsilon: 0.000001 + + +TopdownHeatmapSimpleHead: + in_channels: 768 + num_deconv_layers: 2 + num_deconv_filters: [256,256] + num_deconv_kernels: [4,4] + out_channels: 17 + shift_heatmap: False + flip_pairs: *flip_perm + extra: {final_conv_kernel: 1} + +VitPosePostProcess: + use_dark: True + +KeyPointMSELoss: + use_target_weight: true + loss_scale: 1.0 + +####optimizer +LearningRate: + base_lr: 0.0005 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [170, 200] + - !LinearWarmup + start_factor: 0.001 + steps: 500 + +OptimizerBuilder: + clip_grad_by_norm: 1.0 + optimizer: + type: AdamWDL + betas: [0.9, 0.999] + weight_decay: 0.1 + num_layers: 12 + layer_decay: 0.75 + filter_bias_and_bn: True + skip_decay_names: ['pos_embed','norm'] + set_param_lr_func: 'layerwise_lr_decay' + + + + +#####data +TrainDataset: + !KeypointTopDownCocoDataset + image_dir: train2017 + anno_path: annotations/person_keypoints_train2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + center_scale: 0.4 + + + + +EvalDataset: + !KeypointTopDownCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + image_thre: 0.0 + use_gt_bbox: True + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 4 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomFlipHalfBodyTransform: + scale: 0.5 + rot: 40 + num_joints_half_body: 8 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + flip_pairs: *flip_perm + + - TopDownAffine: + trainsize: *trainsize + use_udp: true + - ToHeatmapsTopDown_UDP: + hmsize: *hmsize + sigma: 2 + + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 64 + shuffle: True + drop_last: True + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + use_udp: true + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 64 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: false diff --git a/configs/keypoint/vit_pose/vitpose_base_simple_coco_256x192.yml b/configs/keypoint/vit_pose/vitpose_base_simple_coco_256x192.yml new file mode 100644 index 000000000..2e34f2593 --- /dev/null +++ b/configs/keypoint/vit_pose/vitpose_base_simple_coco_256x192.yml @@ -0,0 +1,164 @@ +use_gpu: true +log_iter: 50 +save_dir: output +snapshot_epoch: 10 +weights: output/vitpose_base_simple_coco_256x192/model_final +epoch: 210 +num_joints: &num_joints 17 +pixel_std: &pixel_std 200 +metric: KeyPointTopDownCOCOEval +num_classes: 1 +train_height: &train_height 256 +train_width: &train_width 192 +trainsize: &trainsize [*train_width, *train_height] +hmsize: &hmsize [48, 64] +flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] + + +#####model +architecture: VitPose_TopDown +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/mae_pretrain_vit_base.pdparams +VitPose_TopDown: + backbone: ViT + head: TopdownHeatmapSimpleHead + post_process: VitPosePostProcess + loss: KeyPointMSELoss + flip_test: True + +ViT: + img_size: [256, 192] + qkv_bias: True + drop_path_rate: 0.3 + epsilon: 0.000001 + + +TopdownHeatmapSimpleHead: + in_channels: 768 + num_deconv_layers: 0 + num_deconv_filters: [] + num_deconv_kernels: [] + upsample: 4 + shift_heatmap: False + flip_pairs: *flip_perm + extra: {final_conv_kernel: 3} + +VitPosePostProcess: + use_dark: True + +KeyPointMSELoss: + use_target_weight: true + loss_scale: 1.0 + +####optimizer +LearningRate: + base_lr: 0.0005 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [170, 200] + - !LinearWarmup + start_factor: 0.001 + steps: 500 + +OptimizerBuilder: + clip_grad_by_norm: 1.0 + optimizer: + type: AdamWDL + betas: [0.9, 0.999] + weight_decay: 0.1 + num_layers: 12 + layer_decay: 0.75 + filter_bias_and_bn: True + skip_decay_names: ['pos_embed','norm'] + set_param_lr_func: 'layerwise_lr_decay' + + + + +#####data +TrainDataset: + !KeypointTopDownCocoDataset + image_dir: train2017 + anno_path: annotations/person_keypoints_train2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + center_scale: 0.4 + + + +EvalDataset: + !KeypointTopDownCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + trainsize: *trainsize + pixel_std: *pixel_std + image_thre: 0.0 + use_gt_bbox: True + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 4 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - RandomFlipHalfBodyTransform: + scale: 0.5 + rot: 40 + num_joints_half_body: 8 + prob_half_body: 0.3 + pixel_std: *pixel_std + trainsize: *trainsize + upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + flip_pairs: *flip_perm + + - TopDownAffine: + trainsize: *trainsize + use_udp: true + - ToHeatmapsTopDown_UDP: + hmsize: *hmsize + sigma: 2 + + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 64 + shuffle: True + drop_last: True + +EvalReader: + sample_transforms: + - TopDownAffine: + trainsize: *trainsize + use_udp: true + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 64 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: false diff --git a/ppdet/data/source/keypoint_coco.py b/ppdet/data/source/keypoint_coco.py index 11ecea538..6e072dc6e 100644 --- a/ppdet/data/source/keypoint_coco.py +++ b/ppdet/data/source/keypoint_coco.py @@ -491,7 +491,8 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): bbox_file=None, use_gt_bbox=True, pixel_std=200, - image_thre=0.0): + image_thre=0.0, + center_scale=None): super().__init__(dataset_dir, image_dir, anno_path, num_joints, transform) @@ -500,6 +501,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): self.trainsize = trainsize self.pixel_std = pixel_std self.image_thre = image_thre + self.center_scale = center_scale self.dataset_name = 'coco' def parse_dataset(self): @@ -574,6 +576,9 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): center[1] = y + h * 0.5 aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1] + if self.center_scale is not None and np.random.rand() < 0.3: + center += self.center_scale * (np.random.rand(2) - 0.5) * [w, h] + if w > aspect_ratio * h: h = w * 1.0 / aspect_ratio elif w < aspect_ratio * h: diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 8899e5c0b..4c6c5ed0a 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -25,6 +25,7 @@ from . import ttfnet from . import s2anet from . import keypoint_hrhrnet from . import keypoint_hrnet +from . import keypoint_vitpose from . import jde from . import deepsort from . import fairmot @@ -55,6 +56,7 @@ from .ttfnet import * from .s2anet import * from .keypoint_hrhrnet import * from .keypoint_hrnet import * +from .keypoint_vitpose import * from .jde import * from .deepsort import * from .fairmot import * diff --git a/ppdet/modeling/architectures/keypoint_vitpose.py b/ppdet/modeling/architectures/keypoint_vitpose.py new file mode 100644 index 000000000..b00226a83 --- /dev/null +++ b/ppdet/modeling/architectures/keypoint_vitpose.py @@ -0,0 +1,317 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +import math +import cv2 +from ppdet.core.workspace import register, create, serializable +from .meta_arch import BaseArch +from ..keypoint_utils import transform_preds +from .. import layers as L + +__all__ = ['VitPose_TopDown', 'VitPosePostProcess'] + + +@register +class VitPose_TopDown(BaseArch): + __category__ = 'architecture' + __inject__ = ['loss'] + + def __init__(self, backbone, head, loss, post_process, flip_test): + """ + VitPose network, see https://arxiv.org/pdf/2204.12484v2.pdf + + Args: + backbone (nn.Layer): backbone instance + post_process (object): `HRNetPostProcess` instance + + """ + super(VitPose_TopDown, self).__init__() + self.backbone = backbone + self.head = head + self.loss = loss + self.post_process = post_process + self.flip_test = flip_test + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + # backbone + backbone = create(cfg['backbone']) + #head + head = create(cfg['head']) + #post_process + post_process = create(cfg['post_process']) + + return { + 'backbone': backbone, + 'head': head, + 'post_process': post_process + } + + def _forward_train(self): + + feats = self.backbone.forward_features(self.inputs['image']) + vitpost_output = self.head(feats) + return self.loss(vitpost_output, self.inputs) + + def _forward_test(self): + + feats = self.backbone.forward_features(self.inputs['image']) + output_heatmap = self.head(feats) + + if self.flip_test: + img_flipped = self.inputs['image'].flip(3) + features_flipped = self.backbone.forward_features(img_flipped) + output_flipped_heatmap = self.head.inference_model(features_flipped, + self.flip_test) + + output_heatmap = (output_heatmap + output_flipped_heatmap) * 0.5 + + imshape = (self.inputs['im_shape'].numpy() + )[:, ::-1] if 'im_shape' in self.inputs else None + center = self.inputs['center'].numpy( + ) if 'center' in self.inputs else np.round(imshape / 2.) + scale = self.inputs['scale'].numpy( + ) if 'scale' in self.inputs else imshape / 200. + + result = self.post_process(output_heatmap.cpu().numpy(), center, scale) + + return result + + def get_loss(self): + return self._forward_train() + + def get_pred(self): + res_lst = self._forward_test() + outputs = {'keypoint': res_lst} + return outputs + + +@register +@serializable +class VitPosePostProcess(object): + def __init__(self, use_dark=False): + self.use_dark = use_dark + + def get_max_preds(self, heatmaps): + '''get predictions from score maps + + Args: + heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) + + Returns: + preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords + maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints + ''' + assert isinstance(heatmaps, + np.ndarray), 'heatmaps should be numpy.ndarray' + assert heatmaps.ndim == 4, 'batch_images should be 4-ndim' + + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + width = heatmaps.shape[3] + heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1)) + idx = np.argmax(heatmaps_reshaped, 2) + maxvals = np.amax(heatmaps_reshaped, 2) + + maxvals = maxvals.reshape((batch_size, num_joints, 1)) + idx = idx.reshape((batch_size, num_joints, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + + preds[:, :, 0] = (preds[:, :, 0]) % width + preds[:, :, 1] = np.floor((preds[:, :, 1]) // width) + + pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) + pred_mask = pred_mask.astype(np.float32) + + preds *= pred_mask + + return preds, maxvals + + def post_datk_udp(self, coords, batch_heatmaps, kernel=3): + """DARK post-pocessing. Implemented by udp. Paper ref: Huang et al. The + Devil is in the Details: Delving into Unbiased Data Processing for Human + Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + + Note: + - batch size: B + - num keypoints: K + - num persons: N + - height of heatmaps: H + - width of heatmaps: W + + B=1 for bottom_up paradigm where all persons share the same heatmap. + B=N for top_down paradigm where each person has its own heatmaps. + + Args: + coords (np.ndarray[N, K, 2]): Initial coordinates of human pose. + batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps + kernel (int): Gaussian kernel size (K) for modulation. + + Returns: + np.ndarray([N, K, 2]): Refined coordinates. + """ + if not isinstance(batch_heatmaps, np.ndarray): + batch_heatmaps = batch_heatmaps.cpu().numpy() + B, K, H, W = batch_heatmaps.shape + N = coords.shape[0] + assert (B == 1 or B == N) + for heatmaps in batch_heatmaps: + for heatmap in heatmaps: + cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap) + np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps) + np.log(batch_heatmaps, batch_heatmaps) + + batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), + (1, 1)), + mode='edge').flatten() + + index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K) + index = index.astype(int).reshape(-1, 1) + i_ = batch_heatmaps_pad[index] + ix1 = batch_heatmaps_pad[index + 1] + iy1 = batch_heatmaps_pad[index + W + 2] + ix1y1 = batch_heatmaps_pad[index + W + 3] + ix1_y1_ = batch_heatmaps_pad[index - W - 3] + ix1_ = batch_heatmaps_pad[index - 1] + iy1_ = batch_heatmaps_pad[index - 2 - W] + + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(N, K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(N, K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + coords -= np.einsum('ijmn,ijnk->ijmk', hessian, derivative).squeeze() + return coords + + def transform_preds_udp(self, + coords, + center, + scale, + output_size, + use_udp=True): + """Get final keypoint predictions from heatmaps and apply scaling and + translation to map them back to the image. + + Note: + num_keypoints: K + + Args: + coords (np.ndarray[K, ndims]): + + * If ndims=2, corrds are predicted keypoint location. + * If ndims=4, corrds are composed of (x, y, scores, tags) + * If ndims=5, corrds are composed of (x, y, scores, tags, + flipped_tags) + + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + use_udp (bool): Use unbiased data processing + + Returns: + np.ndarray: Predicted coordinates in the images. + """ + + assert coords.shape[1] in (2, 4, 5) + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + + # Recover the scale which is normalized by a factor of 200. + scale = scale * 200.0 + + if use_udp: + scale_x = scale[0] / (output_size[0] - 1.0) + scale_y = scale[1] / (output_size[1] - 1.0) + else: + scale_x = scale[0] / output_size[0] + scale_y = scale[1] / output_size[1] + + target_coords = np.ones_like(coords) + target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[ + 0] * 0.5 + target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[ + 1] * 0.5 + + return target_coords + + def get_final_preds(self, heatmaps, center, scale, kernelsize=11): + """the highest heatvalue location with a quarter offset in the + direction from the highest response to the second highest response. + + Args: + heatmaps (numpy.ndarray): The predicted heatmaps + center (numpy.ndarray): The boxes center + scale (numpy.ndarray): The scale factor + + Returns: + preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords + maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints + """ + coords, maxvals = self.get_max_preds(heatmaps) + + N, K, H, W = heatmaps.shape + + if self.use_dark: + coords = self.post_datk_udp(coords, heatmaps, kernelsize) + preds = coords.copy() + # Transform back to the image + for i in range(N): + preds[i] = self.transform_preds_udp(preds[i], center[i], + scale[i], [W, H]) + else: + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + hm = heatmaps[n][p] + px = int(math.floor(coords[n][p][0] + 0.5)) + py = int(math.floor(coords[n][p][1] + 0.5)) + if 1 < px < W - 1 and 1 < py < H - 1: + diff = np.array([ + hm[py][px + 1] - hm[py][px - 1], + hm[py + 1][px] - hm[py - 1][px] + ]) + coords[n][p] += np.sign(diff) * .25 + preds = coords.copy() + + # Transform back + for i in range(coords.shape[0]): + preds[i] = transform_preds(coords[i], center[i], scale[i], + [W, H]) + + return preds, maxvals + + def __call__(self, output, center, scale): + preds, maxvals = self.get_final_preds(output, center, scale) + outputs = [[ + np.concatenate( + (preds, maxvals), axis=-1), np.mean( + maxvals, axis=1) + ]] + return outputs \ No newline at end of file diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index f8b183e27..a20189c94 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -62,4 +62,5 @@ from .vision_transformer import * from .mobileone import * from .trans_encoder import * from .focalnet import * -from .vit_mae import * +from .vitpose import * +from .vit_mae import * \ No newline at end of file diff --git a/ppdet/modeling/backbones/vitpose.py b/ppdet/modeling/backbones/vitpose.py new file mode 100644 index 000000000..23e00be1e --- /dev/null +++ b/ppdet/modeling/backbones/vitpose.py @@ -0,0 +1,320 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# reference: https://arxiv.org/abs/2010.11929 + +from collections.abc import Callable + +import numpy as np +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +from ppdet.core.workspace import register, serializable + +trunc_normal_ = TruncatedNormal(std=.02) + + +def to_2tuple(x): + if isinstance(x, (list, tuple)): + return x + return tuple([x] * 2) + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = paddle.to_tensor(1.0 - drop_prob).astype(x.dtype) + shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype) + random_tensor = paddle.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + + N, C = x.shape[1:] + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + + x = self.proj_drop(x) + return x + + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer='nn.LayerNorm', + epsilon=1e-5): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm1 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm2 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + num_patches = (img_size[1] // patch_size[1]) * ( + img_size[0] // patch_size[0]) * (ratio**2) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2D( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=(patch_size[0] // ratio), + padding=(4 + 2 * (ratio // 2 - 1), 4 + 2 * (ratio // 2 - 1))) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + x = self.proj(x) + return x + + +@register +@serializable +class ViT(nn.Layer): + """ Vision Transformer with support for patch input + + This module is different from ppdet's VisionTransformer (from ppdet/modeling/backbones/visio_transformer.py), + the main differences are: + 1.the module PatchEmbed.proj has padding set,padding=(4 + 2 * (ratio // 2 - 1), 4 + 2 * (ratio // 2 - 1), + VisionTransformer dose not + 2.Attention module qkv is standard.but VisionTransformer provide more options + 3.MLP module only one Dropout,and VisionTransformer twice; + 4.VisionTransformer provide fpn layer,but the module does not. + + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer='nn.LayerNorm', + epsilon=1e-5, + ratio=1, + pretrained=None, + **kwargs): + super().__init__() + + self.pretrained = pretrained + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ratio=ratio) + num_patches = self.patch_embed.num_patches + + self.pos_embed = self.create_parameter( + shape=(1, num_patches + 1, embed_dim), + default_initializer=trunc_normal_) + self.add_parameter("pos_embed", self.pos_embed) + + dpr = np.linspace(0, drop_path_rate, depth, dtype='float32') + + self.blocks = nn.LayerList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + epsilon=epsilon) for i in range(depth) + ]) + + self.last_norm = eval(norm_layer)(embed_dim, epsilon=epsilon) + trunc_normal_(self.pos_embed) + self._init_weights() + + def _init_weights(self): + pretrained = self.pretrained + + if pretrained: + + if 'http' in pretrained: #URL + path = paddle.utils.download.get_weights_path_from_url( + pretrained) + else: #model in local path + path = pretrained + + load_state_dict = paddle.load(path) + self.set_state_dict(load_state_dict) + print("Load load_state_dict:", path) + + def forward_features(self, x): + + B = paddle.shape(x)[0] + x = self.patch_embed(x) + B, D, Hp, Wp = x.shape + x = x.flatten(2).transpose([0, 2, 1]) + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + x = blk(x) + + x = self.last_norm(x) + xp = paddle.reshape( + paddle.transpose( + x, perm=[0, 2, 1]), shape=[B, -1, Hp, Wp]) + + return xp diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 07df124cd..44a9fa85d 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -39,6 +39,7 @@ from . import yolof_head from . import ppyoloe_contrast_head from . import centertrack_head from . import sparse_roi_head +from . import vitpose_head from .bbox_head import * from .mask_head import * @@ -68,3 +69,4 @@ from .ppyoloe_contrast_head import * from .centertrack_head import * from .sparse_roi_head import * from .petr_head import * +from .vitpose_head import * \ No newline at end of file diff --git a/ppdet/modeling/heads/vitpose_head.py b/ppdet/modeling/heads/vitpose_head.py new file mode 100644 index 000000000..43908ed57 --- /dev/null +++ b/ppdet/modeling/heads/vitpose_head.py @@ -0,0 +1,278 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ppdet.core.workspace import register +from ppdet.modeling.keypoint_utils import resize, flip_back +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +from ppdet.modeling.layers import ConvTranspose2d, BatchNorm2d + +trunc_normal_ = TruncatedNormal(std=.02) +normal_ = Normal(std=0.001) +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + +__all__ = ['TopdownHeatmapSimpleHead'] + + +@register +class TopdownHeatmapSimpleHead(nn.Layer): + def __init__(self, + in_channels=768, + out_channels=17, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None, + in_index=0, + input_transform=None, + align_corners=False, + upsample=0, + flip_pairs=None, + shift_heatmap=False, + target_type='GaussianHeatmap'): + super(TopdownHeatmapSimpleHead, self).__init__() + + self.in_channels = in_channels + self.upsample = upsample + self.flip_pairs = flip_pairs + self.shift_heatmap = shift_heatmap + self.target_type = target_type + + self._init_inputs(in_channels, in_index, input_transform) + self.in_index = in_index + self.align_corners = align_corners + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + conv_channels = num_deconv_filters[ + -1] if num_deconv_layers > 0 else self.in_channels + + layers = [] + if extra is not None: + num_conv_layers = extra.get('num_conv_layers', 0) + num_conv_kernels = extra.get('num_conv_kernels', + [1] * num_conv_layers) + + for i in range(num_conv_layers): + layers.append( + nn.Conv2D( + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=num_conv_kernels[i], + stride=1, + padding=(num_conv_kernels[i] - 1) // 2)) + layers.append(nn.BatchNorm2D(conv_channels)) + layers.append(nn.ReLU()) + + layers.append( + nn.Conv2D( + in_channels=conv_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=(padding, padding))) + + if len(layers) > 1: + self.final_layer = nn.Sequential(*layers) + else: + self.final_layer = layers[0] + + self.init_weights() + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + """ + if not isinstance(inputs, list): + if not isinstance(inputs, list): + + if self.upsample > 0: + inputs = resize( + input=F.relu(inputs), + scale_factor=self.upsample, + mode='bilinear', + align_corners=self.align_corners) + return inputs + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = paddle.concat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def forward(self, x): + """Forward function.""" + x = self._transform_inputs(x) + x = self.deconv_layers(x) + x = self.final_layer(x) + + return x + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_heatmap (np.ndarray): Output heatmaps. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + + if flip_pairs is not None: + output_heatmap = flip_back( + output, self.flip_pairs, target_type=self.target_type) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.shift_heatmap: + output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1] + else: + output_heatmap = output + return output_heatmap + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + ConvTranspose2d( + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2D(planes)) + layers.append(nn.ReLU()) + self.in_channels = planes + + return nn.Sequential(*layers) + + def init_weights(self): + """Initialize model weights.""" + if not isinstance(self.deconv_layers, nn.Identity): + + for m in self.deconv_layers: + if isinstance(m, nn.BatchNorm2D): + ones_(m.weight) + ones_(m.bias) + if not isinstance(self.final_layer, nn.Conv2D): + + for m in self.final_layer: + if isinstance(m, nn.Conv2D): + normal_(m.weight) + zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2D): + ones_(m.weight) + ones_(m.bias) + else: + normal_(self.final_layer.weight) + zeros_(self.final_layer.bias) diff --git a/ppdet/modeling/keypoint_utils.py b/ppdet/modeling/keypoint_utils.py index d5cbeb3ba..377f1d75c 100644 --- a/ppdet/modeling/keypoint_utils.py +++ b/ppdet/modeling/keypoint_utils.py @@ -17,6 +17,7 @@ this code is based on https://github.com/open-mmlab/mmpose import cv2 import numpy as np +import paddle.nn.functional as F def get_affine_mat_kernel(h, w, s, inv=False): @@ -340,3 +341,63 @@ def soft_oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None): keep = keep[:keep_cnt] return keep + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 and + input_w > 1) and (output_h - 1) % (input_h - 1) and + (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +def flip_back(output_flipped, flip_pairs, target_type='GaussianHeatmap'): + """Flip the flipped heatmaps back to the original form. + Note: + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + Args: + output_flipped (np.ndarray[N, K, H, W]): The output heatmaps obtained + from the flipped images. + flip_pairs (list[tuple()): Pairs of keypoints which are mirrored + (for example, left ear -- right ear). + target_type (str): GaussianHeatmap or CombinedTarget + Returns: + np.ndarray: heatmaps that flipped back to the original image + """ + assert len(output_flipped.shape) == 4, \ + 'output_flipped should be [batch_size, num_keypoints, height, width]' + shape_ori = output_flipped.shape + channels = 1 + if target_type.lower() == 'CombinedTarget'.lower(): + channels = 3 + output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...] + output_flipped = output_flipped.reshape((shape_ori[0], -1, channels, + shape_ori[2], shape_ori[3])) + output_flipped_back = output_flipped.clone() + + # Swap left-right parts + for left, right in flip_pairs: + output_flipped_back[:, left, ...] = output_flipped[:, right, ...] + output_flipped_back[:, right, ...] = output_flipped[:, left, ...] + output_flipped_back = output_flipped_back.reshape(shape_ori) + # Flip horizontally + output_flipped_back = output_flipped_back[..., ::-1] + return output_flipped_back diff --git a/ppdet/optimizer/adamw.py b/ppdet/optimizer/adamw.py index 6ecf676d6..12ab619a3 100644 --- a/ppdet/optimizer/adamw.py +++ b/ppdet/optimizer/adamw.py @@ -50,7 +50,7 @@ def layerwise_lr_decay(decay_rate, name_dict, n_layers, param): layer = int(static_name[idx:].split('.')[1]) ratio = decay_rate**(n_layers - layer) - elif 'cls_token' in static_name or 'patch_embed' in static_name: + elif 'cls_token' in static_name or 'patch_embed' in static_name or 'pos_embed' in static_name: ratio = decay_rate**(n_layers + 1) if IS_PADDLE_LATER_2_4: -- GitLab