diff --git a/.gitignore b/.gitignore index 6a98a38b72ef9a59a2fdab266697661cdd1fa136..4b6a6e8246385c676e00a412f6030ec4100d090f 100644 --- a/.gitignore +++ b/.gitignore @@ -18,9 +18,9 @@ __pycache__/ # Distribution / packaging /bin/ -/build/ +*build/ /develop-eggs/ -/dist/ +*dist/ /eggs/ /lib/ /lib64/ @@ -30,7 +30,7 @@ __pycache__/ /parts/ /sdist/ /var/ -/*.egg-info/ +*.egg-info/ /.installed.cfg /*.egg /.eggs diff --git a/configs/keypoint/README.md b/configs/keypoint/README.md index a1e70b2e74035887fdf56589566e04965d039810..f81b2fabbd4a27c5bb7a56fca7abce34660af556 100644 --- a/configs/keypoint/README.md +++ b/configs/keypoint/README.md @@ -56,8 +56,10 @@ PaddleDetection 中的关键点检测部分紧跟最先进的算法,包括 Top ## 模型库 COCO数据集 + | 模型 | 方案 |输入尺寸 | AP(coco val) | 模型下载 | 配置文件 | | :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------| ------- | +| PETR_Res50 |One-Stage| 512 | 65.5 | [petr_res50.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/petr_resnet50_16x2_coco.pdparams) | [config](./petr/petr_resnet50_16x2_coco.yml) | | HigherHRNet-w32 |Bottom-Up| 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) | | HigherHRNet-w32 | Bottom-Up| 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) | | HigherHRNet-w32+SWAHR |Bottom-Up| 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) | diff --git a/configs/keypoint/README_en.md b/configs/keypoint/README_en.md index a927d80cf51b3a75443e7b321f2fd1f7ffe910ed..64ffc39d61c63a3893b079e255facaec3620aeb6 100644 --- a/configs/keypoint/README_en.md +++ b/configs/keypoint/README_en.md @@ -62,6 +62,7 @@ At the same time, PaddleDetection provides a self-developed real-time keypoint d COCO Dataset | Model | Input Size | AP(coco val) | Model Download | Config File | | :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------------- | +| PETR_Res50 |One-Stage| 512 | 65.5 | [petr_res50.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/petr_resnet50_16x2_coco.pdparams) | [config](./petr/petr_resnet50_16x2_coco.yml) | | HigherHRNet-w32 | 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) | | HigherHRNet-w32 | 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) | | HigherHRNet-w32+SWAHR | 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) | diff --git a/configs/keypoint/petr/petr_resnet50_16x2_coco.yml b/configs/keypoint/petr/petr_resnet50_16x2_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..d6415ad3b8b7334e7c19517ab7ee2f99e87dbc0c --- /dev/null +++ 
b/configs/keypoint/petr/petr_resnet50_16x2_coco.yml @@ -0,0 +1,255 @@ +use_gpu: true +log_iter: 50 +save_dir: output +snapshot_epoch: 1 +weights: output/petr_resnet50_16x2_coco/model_final +epoch: 100 +num_joints: &num_joints 17 +pixel_std: &pixel_std 200 +metric: COCO +num_classes: 1 +trainsize: &trainsize 512 +flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] +find_unused_parameters: False + +#####model +architecture: PETR +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/PETR_pretrained.pdparams + +PETR: + backbone: + name: ResNet + depth: 50 + variant: b + norm_type: bn + freeze_norm: True + freeze_at: 0 + return_idx: [1,2,3] + num_stages: 4 + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + neck: + name: ChannelMapper + in_channels: [512, 1024, 2048] + kernel_size: 1 + out_channels: 256 + norm_type: "gn" + norm_groups: 32 + act: None + num_outs: 4 + bbox_head: + name: PETRHead + num_query: 300 + num_classes: 1 # only person + in_channels: 2048 + sync_cls_avg_factor: true + with_kpt_refine: true + transformer: + name: PETRTransformer + as_two_stage: true + encoder: + name: TransformerEncoder + encoder_layer: + name: TransformerEncoderLayer + d_model: 256 + attn: + name: MSDeformableAttention + embed_dim: 256 + num_heads: 8 + num_levels: 4 + num_points: 4 + dim_feedforward: 1024 + dropout: 0.1 + num_layers: 6 + decoder: + name: PETR_TransformerDecoder + num_layers: 3 + return_intermediate: true + decoder_layer: + name: PETR_TransformerDecoderLayer + d_model: 256 + dim_feedforward: 1024 + dropout: 0.1 + self_attn: + name: MultiHeadAttention + embed_dim: 256 + num_heads: 8 + dropout: 0.1 + cross_attn: + name: MultiScaleDeformablePoseAttention + embed_dims: 256 + num_heads: 8 + num_levels: 4 + num_points: 17 + hm_encoder: + name: TransformerEncoder + encoder_layer: + name: TransformerEncoderLayer + d_model: 256 + attn: + name: MSDeformableAttention + embed_dim: 256 + num_heads: 8 + num_levels: 1 + num_points: 4 + dim_feedforward: 1024 + dropout: 0.1 + num_layers: 1 + refine_decoder: + name: PETR_DeformableDetrTransformerDecoder + num_layers: 2 + return_intermediate: true + decoder_layer: + name: PETR_TransformerDecoderLayer + d_model: 256 + dim_feedforward: 1024 + dropout: 0.1 + self_attn: + name: MultiHeadAttention + embed_dim: 256 + num_heads: 8 + dropout: 0.1 + cross_attn: + name: MSDeformableAttention + embed_dim: 256 + num_levels: 4 + positional_encoding: + name: PositionEmbedding + num_pos_feats: 128 + normalize: true + offset: -0.5 + loss_cls: + name: Weighted_FocalLoss + use_sigmoid: true + gamma: 2.0 + alpha: 0.25 + loss_weight: 2.0 + reduction: "mean" + loss_kpt: + name: L1Loss + loss_weight: 70.0 + loss_kpt_rpn: + name: L1Loss + loss_weight: 70.0 + loss_oks: + name: OKSLoss + loss_weight: 2.0 + loss_hm: + name: CenterFocalLoss + loss_weight: 4.0 + loss_kpt_refine: + name: L1Loss + loss_weight: 80.0 + loss_oks_refine: + name: OKSLoss + loss_weight: 3.0 + assigner: + name: PoseHungarianAssigner + cls_cost: + name: FocalLossCost + weight: 2.0 + kpt_cost: + name: KptL1Cost + weight: 70.0 + oks_cost: + name: OksCost + weight: 7.0 + +#####optimizer +LearningRate: + base_lr: 0.0002 + schedulers: + - !PiecewiseDecay + milestones: [80] + gamma: 0.1 + use_warmup: false + # - !LinearWarmup + # start_factor: 0.001 + # steps: 1000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + optimizer: + type: AdamW + regularizer: + factor: 0.0001 + type: L2 + + +#####data +TrainDataset: + !KeypointBottomUpCocoDataset + image_dir: train2017 + anno_path: 
annotations/person_keypoints_train2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + return_mask: false + +EvalDataset: + !KeypointBottomUpCocoDataset + image_dir: val2017 + anno_path: annotations/person_keypoints_val2017.json + dataset_dir: dataset/coco + num_joints: *num_joints + test_mode: true + return_mask: false + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 2 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - Decode: {} + - PhotoMetricDistortion: + brightness_delta: 32 + contrast_range: [0.5, 1.5] + saturation_range: [0.5, 1.5] + hue_delta: 18 + - KeyPointFlip: + flip_prob: 0.5 + flip_permutation: *flip_perm + - RandomAffine: + max_degree: 30 + scale: [1.0, 1.0] + max_shift: 0. + trainsize: -1 + - RandomSelect: { transforms1: [ RandomShortSideRangeResize: { scales: [[400, 1400], [1400, 1400]]} ], + transforms2: [ + RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] }, + RandomSizeCrop: { min_size: 384, max_size: 600}, + RandomShortSideRangeResize: { scales: [[400, 1400], [1400, 1400]]} ]} + batch_transforms: + - NormalizeImage: {mean: *global_mean, std: *global_std, is_scale: True} + - PadGT: {pad_img: True, minimum_gtnum: 1} + - Permute: {} + batch_size: 2 + shuffle: true + drop_last: true + use_shared_memory: true + collate_batch: true + +EvalReader: + sample_transforms: + - PETR_Resize: {img_scale: [[800, 1333]], keep_ratio: True} + # - MultiscaleTestResize: {origin_target_size: [[800, 1333]], use_flip: false} + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + +TestReader: + sample_transforms: + - Decode: {} + - EvalAffine: + size: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 diff --git a/configs/pose3d/tinypose3d_human36M.yml b/configs/pose3d/tinypose3d_human36M.yml new file mode 100644 index 0000000000000000000000000000000000000000..a3ccdbbbd588013cb87418a269f0abc9dde17dea --- /dev/null +++ b/configs/pose3d/tinypose3d_human36M.yml @@ -0,0 +1,123 @@ +use_gpu: true +log_iter: 5 +save_dir: output +snapshot_epoch: 1 +weights: output/tinypose3d_human36M/model_final +epoch: 220 +num_joints: &num_joints 24 +pixel_std: &pixel_std 200 +metric: Pose3DEval +num_classes: 1 +train_height: &train_height 128 +train_width: &train_width 128 +trainsize: &trainsize [*train_width, *train_height] + +#####model +architecture: TinyPose3DHRNet +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.pdparams + +TinyPose3DHRNet: + backbone: LiteHRNet + post_process: HR3DNetPostProcess + fc_channel: 1024 + num_joints: *num_joints + width: &width 40 + loss: Pose3DLoss + +LiteHRNet: + network_type: wider_naive + freeze_at: -1 + freeze_norm: false + return_idx: [0] + +Pose3DLoss: + weight_3d: 1.0 + weight_2d: 0.0 + +#####optimizer +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + milestones: [17, 21] + gamma: 0.1 + - !LinearWarmup + start_factor: 0.01 + steps: 1000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: + factor: 0.0 + type: L2 + + +#####data +TrainDataset: + !Pose3DDataset + dataset_dir: Human3.6M + image_dirs: ["Images"] + anno_list: ['Human3.6m_train.json'] + num_joints: *num_joints + test_mode: False + +EvalDataset: + !Pose3DDataset + dataset_dir: Human3.6M + image_dirs: ["Images"] + anno_list: ['Human3.6m_valid.json'] + num_joints: *num_joints + 
test_mode: True + +TestDataset: + !ImageFolder + anno_path: dataset/coco/keypoint_imagelist.txt + +worker_num: 4 +global_mean: &global_mean [0.485, 0.456, 0.406] +global_std: &global_std [0.229, 0.224, 0.225] +TrainReader: + sample_transforms: + - SinglePoseAffine: + trainsize: *trainsize + rotate: [0.5, 30] #[prob, rotate range] + scale: [0.5, 0.25] #[prob, scale range] + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 128 + shuffle: true + drop_last: true + +EvalReader: + sample_transforms: + - SinglePoseAffine: + trainsize: *trainsize + rotate: [0., 30] + scale: [0., 0.25] + batch_transforms: + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 128 + +TestReader: + inputs_def: + image_shape: [3, *train_height, *train_width] + sample_transforms: + - Decode: {} + - TopDownEvalAffine: + trainsize: *trainsize + - NormalizeImage: + mean: *global_mean + std: *global_std + is_scale: true + - Permute: {} + batch_size: 1 + fuse_normalize: false diff --git a/docs/tutorials/data/PrepareKeypointDataSet.md b/docs/tutorials/data/PrepareKeypointDataSet.md index 4efa90b8d2b2a70430c13feccffe0342ce94e5fd..27d844c03482047dfa47db1985b10fecef9ee74b 100644 --- a/docs/tutorials/data/PrepareKeypointDataSet.md +++ b/docs/tutorials/data/PrepareKeypointDataSet.md @@ -82,7 +82,7 @@ MPII keypoint indexes: ``` { 'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], - 'joints': [ + 'gt_joints': [ [-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], diff --git a/docs/tutorials/data/PrepareKeypointDataSet_en.md b/docs/tutorials/data/PrepareKeypointDataSet_en.md index 80272910cee355e28d6aa219e30bc98de599bbd0..6ed566d171a9fa6888ff2caaa3a4df521a97ebfa 100644 --- a/docs/tutorials/data/PrepareKeypointDataSet_en.md +++ b/docs/tutorials/data/PrepareKeypointDataSet_en.md @@ -82,7 +82,7 @@ The following example takes a parsed annotation information to illustrate the co ``` { 'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], - 'joints': [ + 'gt_joints': [ [-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], diff --git a/ppdet/data/source/keypoint_coco.py b/ppdet/data/source/keypoint_coco.py index 45eb9a91d7381d649ead2bb70954eae7acac76d0..11ecea538404bf498c66a90f4e8293824edbf317 100644 --- a/ppdet/data/source/keypoint_coco.py +++ b/ppdet/data/source/keypoint_coco.py @@ -80,7 +80,8 @@ class KeypointBottomUpBaseDataset(DetDataset): records = copy.deepcopy(self._get_imganno(idx)) records['image'] = cv2.imread(records['image_file']) records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB) - records['mask'] = (records['mask'] + 0).astype('uint8') + if 'mask' in records: + records['mask'] = (records['mask'] + 0).astype('uint8') records = self.transform(records) return records @@ -135,24 +136,37 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): num_joints, transform=[], shard=[0, 1], - test_mode=False): + test_mode=False, + return_mask=True, + return_bbox=True, + return_area=True, + return_class=True): super().__init__(dataset_dir, image_dir, anno_path, num_joints, transform, shard, test_mode) self.ann_file = os.path.join(dataset_dir, anno_path) self.shard = shard self.test_mode = test_mode + self.return_mask = return_mask + self.return_bbox = return_bbox + self.return_area = return_area + self.return_class = return_class def parse_dataset(self): self.coco = COCO(self.ann_file) self.img_ids = self.coco.getImgIds() if not self.test_mode: - self.img_ids = [ - img_id for img_id in 
self.img_ids - if len(self.coco.getAnnIds( - imgIds=img_id, iscrowd=None)) > 0 - ] + self.img_ids_tmp = [] + for img_id in self.img_ids: + ann_ids = self.coco.getAnnIds(imgIds=img_id) + anno = self.coco.loadAnns(ann_ids) + anno = [obj for obj in anno if obj['iscrowd'] == 0] + if len(anno) == 0: + continue + self.img_ids_tmp.append(img_id) + self.img_ids = self.img_ids_tmp + blocknum = int(len(self.img_ids) / self.shard[1]) self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * ( self.shard[0] + 1))] @@ -199,21 +213,31 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): ann_ids = coco.getAnnIds(imgIds=img_id) anno = coco.loadAnns(ann_ids) - mask = self._get_mask(anno, idx) anno = [ obj for obj in anno - if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0 + if obj['iscrowd'] == 0 and obj['num_keypoints'] > 0 ] + db_rec = {} joints, orgsize = self._get_joints(anno, idx) + db_rec['gt_joints'] = joints + db_rec['im_shape'] = orgsize + + if self.return_bbox: + db_rec['gt_bbox'] = self._get_bboxs(anno, idx) + + if self.return_class: + db_rec['gt_class'] = self._get_labels(anno, idx) + + if self.return_area: + db_rec['gt_areas'] = self._get_areas(anno, idx) + + if self.return_mask: + db_rec['mask'] = self._get_mask(anno, idx) - db_rec = {} db_rec['im_id'] = img_id db_rec['image_file'] = os.path.join(self.img_prefix, self.id2name[img_id]) - db_rec['mask'] = mask - db_rec['joints'] = joints - db_rec['im_shape'] = orgsize return db_rec @@ -229,12 +253,41 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): np.array(obj['keypoints']).reshape([-1, 3]) img_info = self.coco.loadImgs(self.img_ids[idx])[0] - joints[..., 0] /= img_info['width'] - joints[..., 1] /= img_info['height'] - orgsize = np.array([img_info['height'], img_info['width']]) + orgsize = np.array([img_info['height'], img_info['width'], 1]) return joints, orgsize + def _get_bboxs(self, anno, idx): + num_people = len(anno) + gt_bboxes = np.zeros((num_people, 4), dtype=np.float32) + + for idx, obj in enumerate(anno): + if 'bbox' in obj: + gt_bboxes[idx, :] = obj['bbox'] + + gt_bboxes[:, 2] += gt_bboxes[:, 0] + gt_bboxes[:, 3] += gt_bboxes[:, 1] + return gt_bboxes + + def _get_labels(self, anno, idx): + num_people = len(anno) + gt_labels = np.zeros((num_people, 1), dtype=np.float32) + + for idx, obj in enumerate(anno): + if 'category_id' in obj: + catid = obj['category_id'] + gt_labels[idx, 0] = self.catid2clsid[catid] + return gt_labels + + def _get_areas(self, anno, idx): + num_people = len(anno) + gt_areas = np.zeros((num_people, ), dtype=np.float32) + + for idx, obj in enumerate(anno): + if 'area' in obj: + gt_areas[idx, ] = obj['area'] + return gt_areas + def _get_mask(self, anno, idx): """Get ignore masks to mask out losses.""" coco = self.coco @@ -506,7 +559,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): 'image_file': os.path.join(self.img_prefix, file_name), 'center': center, 'scale': scale, - 'joints': joints, + 'gt_joints': joints, 'joints_vis': joints_vis, 'im_id': im_id, }) @@ -570,7 +623,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): 'center': center, 'scale': scale, 'score': score, - 'joints': joints, + 'gt_joints': joints, 'joints_vis': joints_vis, }) @@ -647,8 +700,8 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): (self.ann_info['num_joints'], 3), dtype=np.float32) joints_vis = np.zeros( (self.ann_info['num_joints'], 3), dtype=np.float32) - if 'joints' in a: - joints_ = np.array(a['joints']) + if 'gt_joints' in a: + joints_ = 
np.array(a['gt_joints']) joints_[:, 0:2] = joints_[:, 0:2] - 1 joints_vis_ = np.array(a['joints_vis']) assert len(joints_) == self.ann_info[ @@ -664,7 +717,7 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): 'im_id': im_id, 'center': c, 'scale': s, - 'joints': joints, + 'gt_joints': joints, 'joints_vis': joints_vis }) print("number length: {}".format(len(gt_db))) diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 92c211ee415dac977da211217bd61a5db9857153..2637db43d217e5b9bcbc7900f396f03bf4f5319e 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -1102,13 +1102,115 @@ class PadGT(BaseOperator): 1 means bbox, 0 means no bbox. """ - def __init__(self, return_gt_mask=True): + def __init__(self, return_gt_mask=True, pad_img=False, minimum_gtnum=0): super(PadGT, self).__init__() self.return_gt_mask = return_gt_mask + self.pad_img = pad_img + self.minimum_gtnum = minimum_gtnum + + def _impad(self, img: np.ndarray, + *, + shape = None, + padding = None, + pad_val = 0, + padding_mode = 'constant') -> np.ndarray: + """Pad the given image to a certain shape or pad on all sides with + specified padding mode and padding value. + + Args: + img (ndarray): Image to be padded. + shape (tuple[int]): Expected padding shape (h, w). Default: None. + padding (int or tuple[int]): Padding on each border. If a single int is + provided this is used to pad all borders. If tuple of length 2 is + provided this is the padding on left/right and top/bottom + respectively. If a tuple of length 4 is provided this is the + padding for the left, top, right and bottom borders respectively. + Default: None. Note that `shape` and `padding` can not be both + set. + pad_val (Number | Sequence[Number]): Values to be filled in padding + areas when padding_mode is 'constant'. Default: 0. + padding_mode (str): Type of padding. Should be: constant, edge, + reflect or symmetric. Default: constant. + - constant: pads with a constant value, this value is specified + with pad_val. + - edge: pads with the last value at the edge of the image. + - reflect: pads with reflection of image without repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with 2 + elements on both sides in reflect mode will result in + [3, 2, 1, 2, 3, 4, 3, 2]. + - symmetric: pads with reflection of image repeating the last value + on the edge. For example, padding [1, 2, 3, 4] with 2 elements on + both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + ndarray: The padded image. + """ + + assert (shape is not None) ^ (padding is not None) + if shape is not None: + width = max(shape[1] - img.shape[1], 0) + height = max(shape[0] - img.shape[0], 0) + padding = (0, 0, int(width), int(height)) + + # check pad_val + import numbers + if isinstance(pad_val, tuple): + assert len(pad_val) == img.shape[-1] + elif not isinstance(pad_val, numbers.Number): + raise TypeError('pad_val must be a int or a tuple. ' + f'But received {type(pad_val)}') + + # check padding + if isinstance(padding, tuple) and len(padding) in [2, 4]: + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + elif isinstance(padding, numbers.Number): + padding = (padding, padding, padding, padding) + else: + raise ValueError('Padding must be a int or a 2, or 4 element tuple.' 
+ f'But received {padding}') + + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + + border_type = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT + } + img = cv2.copyMakeBorder( + img, + padding[1], + padding[3], + padding[0], + padding[2], + border_type[padding_mode], + value=pad_val) + + return img + + def checkmaxshape(self, samples): + maxh, maxw = 0, 0 + for sample in samples: + h,w = sample['im_shape'] + if h>maxh: + maxh = h + if w>maxw: + maxw = w + return (maxh, maxw) def __call__(self, samples, context=None): num_max_boxes = max([len(s['gt_bbox']) for s in samples]) + num_max_boxes = max(self.minimum_gtnum, num_max_boxes) + if self.pad_img: + maxshape = self.checkmaxshape(samples) for sample in samples: + if self.pad_img: + img = sample['image'] + padimg = self._impad(img, shape=maxshape) + sample['image'] = padimg if self.return_gt_mask: sample['pad_gt_mask'] = np.zeros( (num_max_boxes, 1), dtype=np.float32) @@ -1142,6 +1244,17 @@ class PadGT(BaseOperator): if num_gt > 0: pad_diff[:num_gt] = sample['difficult'] sample['difficult'] = pad_diff + if 'gt_joints' in sample: + num_joints = sample['gt_joints'].shape[1] + pad_gt_joints = np.zeros((num_max_boxes, num_joints, 3), dtype=np.float32) + if num_gt > 0: + pad_gt_joints[:num_gt] = sample['gt_joints'] + sample['gt_joints'] = pad_gt_joints + if 'gt_areas' in sample: + pad_gt_areas = np.zeros((num_max_boxes, 1), dtype=np.float32) + if num_gt > 0: + pad_gt_areas[:num_gt, 0] = sample['gt_areas'] + sample['gt_areas'] = pad_gt_areas return samples diff --git a/ppdet/data/transform/keypoint_operators.py b/ppdet/data/transform/keypoint_operators.py index 9c7db162fc50b21e6b5ae529a033b395119fe68d..24cf63b886063ef9e1aa741cab3d33b36a1b123e 100644 --- a/ppdet/data/transform/keypoint_operators.py +++ b/ppdet/data/transform/keypoint_operators.py @@ -41,7 +41,7 @@ __all__ = [ 'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK', 'ToHeatmapsTopDown_UDP', 'TopDownEvalAffine', 'AugmentationbyInformantionDropping', 'SinglePoseAffine', 'NoiseJitter', - 'FlipPose' + 'FlipPose', 'PETR_Resize' ] @@ -65,38 +65,77 @@ class KeyPointFlip(object): """ - def __init__(self, flip_permutation, hmsize, flip_prob=0.5): + def __init__(self, flip_permutation, hmsize=None, flip_prob=0.5): super(KeyPointFlip, self).__init__() assert isinstance(flip_permutation, Sequence) self.flip_permutation = flip_permutation self.flip_prob = flip_prob self.hmsize = hmsize - def __call__(self, records): - image = records['image'] - kpts_lst = records['joints'] - mask_lst = records['mask'] - flip = np.random.random() < self.flip_prob - if flip: - image = image[:, ::-1] - for idx, hmsize in enumerate(self.hmsize): - if len(mask_lst) > idx: - mask_lst[idx] = mask_lst[idx][:, ::-1] + def _flipjoints(self, records, sizelst): + ''' + records['gt_joints'] is Sequence in higherhrnet + ''' + if not ('gt_joints' in records and records['gt_joints'].size > 0): + return records + + kpts_lst = records['gt_joints'] + if isinstance(kpts_lst, Sequence): + for idx, hmsize in enumerate(sizelst): if kpts_lst[idx].ndim == 3: kpts_lst[idx] = kpts_lst[idx][:, self.flip_permutation] else: kpts_lst[idx] = kpts_lst[idx][self.flip_permutation] kpts_lst[idx][..., 0] = hmsize - kpts_lst[idx][..., 0] - kpts_lst[idx] = kpts_lst[idx].astype(np.int64) - kpts_lst[idx][kpts_lst[idx][..., 0] >= hmsize, 2] = 0 - kpts_lst[idx][kpts_lst[idx][..., 1] >= hmsize, 2] = 0 - 
kpts_lst[idx][kpts_lst[idx][..., 0] < 0, 2] = 0 - kpts_lst[idx][kpts_lst[idx][..., 1] < 0, 2] = 0 - records['image'] = image - records['joints'] = kpts_lst + else: + hmsize = sizelst[0] + if kpts_lst.ndim == 3: + kpts_lst = kpts_lst[:, self.flip_permutation] + else: + kpts_lst = kpts_lst[self.flip_permutation] + kpts_lst[..., 0] = hmsize - kpts_lst[..., 0] + + records['gt_joints'] = kpts_lst + return records + + def _flipmask(self, records, sizelst): + if not 'mask' in records: + return records + + mask_lst = records['mask'] + for idx, hmsize in enumerate(sizelst): + if len(mask_lst) > idx: + mask_lst[idx] = mask_lst[idx][:, ::-1] records['mask'] = mask_lst return records + def _flipbbox(self, records, sizelst): + if not 'gt_bbox' in records: + return records + + bboxes = records['gt_bbox'] + hmsize = sizelst[0] + bboxes[:, 0::2] = hmsize - bboxes[:, 0::2][:, ::-1] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, hmsize) + records['gt_bbox'] = bboxes + return records + + def __call__(self, records): + flip = np.random.random() < self.flip_prob + if flip: + image = records['image'] + image = image[:, ::-1] + records['image'] = image + if self.hmsize is None: + sizelst = [image.shape[1]] + else: + sizelst = self.hmsize + self._flipjoints(records, sizelst) + self._flipmask(records, sizelst) + self._flipbbox(records, sizelst) + + return records + @register_keypointop class RandomAffine(object): @@ -121,9 +160,10 @@ class RandomAffine(object): max_degree=30, scale=[0.75, 1.5], max_shift=0.2, - hmsize=[128, 256], + hmsize=None, trainsize=512, - scale_type='short'): + scale_type='short', + boldervalue=[114, 114, 114]): super(RandomAffine, self).__init__() self.max_degree = max_degree self.min_scale = scale[0] @@ -132,8 +172,9 @@ class RandomAffine(object): self.hmsize = hmsize self.trainsize = trainsize self.scale_type = scale_type + self.boldervalue = boldervalue - def _get_affine_matrix(self, center, scale, res, rot=0): + def _get_affine_matrix_old(self, center, scale, res, rot=0): """Generate transformation matrix.""" h = scale t = np.zeros((3, 3), dtype=np.float32) @@ -159,21 +200,94 @@ class RandomAffine(object): t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) return t + def _get_affine_matrix(self, center, scale, res, rot=0): + """Generate transformation matrix.""" + w, h = scale + t = np.zeros((3, 3), dtype=np.float32) + t[0, 0] = float(res[0]) / w + t[1, 1] = float(res[1]) / h + t[0, 2] = res[0] * (-float(center[0]) / w + .5) + t[1, 2] = res[1] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if rot != 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3), dtype=np.float32) + rot_rad = rot * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0, 2] = -res[0] / 2 + t_mat[1, 2] = -res[1] / 2 + t_inv = t_mat.copy() + t_inv[:2, 2] *= -1 + t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) + return t + + def _affine_joints_mask(self, + degree, + center, + roi_size, + dsize, + keypoints=None, + heatmap_mask=None, + gt_bbox=None): + kpts = None + mask = None + bbox = None + mask_affine_mat = self._get_affine_matrix(center, roi_size, dsize, + degree)[:2] + if heatmap_mask is not None: + mask = cv2.warpAffine(heatmap_mask, mask_affine_mat, dsize) + mask = ((mask / 255) > 0.5).astype(np.float32) + if keypoints is not None: + kpts = copy.deepcopy(keypoints) + kpts[..., 0:2] = warp_affine_joints(kpts[..., 0:2].copy(), + 
mask_affine_mat) + kpts[(kpts[..., 0]) > dsize[0], :] = 0 + kpts[(kpts[..., 1]) > dsize[1], :] = 0 + kpts[(kpts[..., 0]) < 0, :] = 0 + kpts[(kpts[..., 1]) < 0, :] = 0 + if gt_bbox is not None: + temp_bbox = gt_bbox[:, [0, 3, 2, 1]] + cat_bbox = np.concatenate((gt_bbox, temp_bbox), axis=-1) + gt_bbox_warped = warp_affine_joints(cat_bbox, mask_affine_mat) + bbox = np.zeros_like(gt_bbox) + bbox[:, 0] = gt_bbox_warped[:, 0::2].min(1).clip(0, dsize[0]) + bbox[:, 2] = gt_bbox_warped[:, 0::2].max(1).clip(0, dsize[0]) + bbox[:, 1] = gt_bbox_warped[:, 1::2].min(1).clip(0, dsize[1]) + bbox[:, 3] = gt_bbox_warped[:, 1::2].max(1).clip(0, dsize[1]) + return kpts, mask, bbox + def __call__(self, records): image = records['image'] - keypoints = records['joints'] - heatmap_mask = records['mask'] + shape = np.array(image.shape[:2][::-1]) + keypoints = None + heatmap_mask = None + gt_bbox = None + if 'gt_joints' in records: + keypoints = records['gt_joints'] + + if 'mask' in records: + heatmap_mask = records['mask'] + heatmap_mask *= 255 + + if 'gt_bbox' in records: + gt_bbox = records['gt_bbox'] degree = (np.random.random() * 2 - 1) * self.max_degree - shape = np.array(image.shape[:2][::-1]) center = center = np.array((np.array(shape) / 2)) aug_scale = np.random.random() * (self.max_scale - self.min_scale ) + self.min_scale if self.scale_type == 'long': - scale = max(shape[0], shape[1]) / 1.0 + scale = np.array([max(shape[0], shape[1]) / 1.0] * 2) elif self.scale_type == 'short': - scale = min(shape[0], shape[1]) / 1.0 + scale = np.array([min(shape[0], shape[1]) / 1.0] * 2) + elif self.scale_type == 'wh': + scale = shape else: raise ValueError('Unknown scale type: {}'.format(self.scale_type)) roi_size = aug_scale * scale @@ -181,44 +295,55 @@ class RandomAffine(object): dy = int(0) if self.max_shift > 0: - dx = np.random.randint(-self.max_shift * roi_size, - self.max_shift * roi_size) - dy = np.random.randint(-self.max_shift * roi_size, - self.max_shift * roi_size) + dx = np.random.randint(-self.max_shift * roi_size[0], + self.max_shift * roi_size[0]) + dy = np.random.randint(-self.max_shift * roi_size[0], + self.max_shift * roi_size[1]) center += np.array([dx, dy]) input_size = 2 * center + if self.trainsize != -1: + dsize = self.trainsize + imgshape = (dsize, dsize) + else: + dsize = scale + imgshape = (shape.tolist()) - keypoints[..., :2] *= shape - heatmap_mask *= 255 - kpts_lst = [] - mask_lst = [] - - image_affine_mat = self._get_affine_matrix( - center, roi_size, (self.trainsize, self.trainsize), degree)[:2] + image_affine_mat = self._get_affine_matrix(center, roi_size, dsize, + degree)[:2] image = cv2.warpAffine( image, - image_affine_mat, (self.trainsize, self.trainsize), - flags=cv2.INTER_LINEAR) + image_affine_mat, + imgshape, + flags=cv2.INTER_LINEAR, + borderValue=self.boldervalue) + + if self.hmsize is None: + kpts, mask, gt_bbox = self._affine_joints_mask( + degree, center, roi_size, dsize, keypoints, heatmap_mask, + gt_bbox) + records['image'] = image + if kpts is not None: records['gt_joints'] = kpts + if mask is not None: records['mask'] = mask + if gt_bbox is not None: records['gt_bbox'] = gt_bbox + return records + + kpts_lst = [] + mask_lst = [] for hmsize in self.hmsize: - kpts = copy.deepcopy(keypoints) - mask_affine_mat = self._get_affine_matrix( - center, roi_size, (hmsize, hmsize), degree)[:2] - if heatmap_mask is not None: - mask = cv2.warpAffine(heatmap_mask, mask_affine_mat, - (hmsize, hmsize)) - mask = ((mask / 255) > 0.5).astype(np.float32) - kpts[..., 0:2] = 
warp_affine_joints(kpts[..., 0:2].copy(), - mask_affine_mat) - kpts[np.trunc(kpts[..., 0]) >= hmsize, 2] = 0 - kpts[np.trunc(kpts[..., 1]) >= hmsize, 2] = 0 - kpts[np.trunc(kpts[..., 0]) < 0, 2] = 0 - kpts[np.trunc(kpts[..., 1]) < 0, 2] = 0 + kpts, mask, gt_bbox = self._affine_joints_mask( + degree, center, roi_size, [hmsize, hmsize], keypoints, + heatmap_mask, gt_bbox) kpts_lst.append(kpts) mask_lst.append(mask) records['image'] = image - records['joints'] = kpts_lst - records['mask'] = mask_lst + + if 'gt_joints' in records: + records['gt_joints'] = kpts_lst + if 'mask' in records: + records['mask'] = mask_lst + if 'gt_bbox' in records: + records['gt_bbox'] = gt_bbox return records @@ -251,8 +376,8 @@ class EvalAffine(object): if mask is not None: mask = cv2.warpAffine(mask, trans, size_resized) records['mask'] = mask - if 'joints' in records: - del records['joints'] + if 'gt_joints' in records: + del records['gt_joints'] records['image'] = image_resized return records @@ -303,7 +428,7 @@ class TagGenerate(object): self.num_joints = num_joints def __call__(self, records): - kpts_lst = records['joints'] + kpts_lst = records['gt_joints'] kpts = kpts_lst[0] tagmap = np.zeros((self.max_people, self.num_joints, 4), dtype=np.int64) inds = np.where(kpts[..., 2] > 0) @@ -315,7 +440,7 @@ class TagGenerate(object): tagmap[p, j, 2] = visible[..., 0] # x tagmap[p, j, 3] = 1 records['tagmap'] = tagmap - del records['joints'] + del records['gt_joints'] return records @@ -349,7 +474,7 @@ class ToHeatmaps(object): self.gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) def __call__(self, records): - kpts_lst = records['joints'] + kpts_lst = records['gt_joints'] mask_lst = records['mask'] for idx, hmsize in enumerate(self.hmsize): mask = mask_lst[idx] @@ -470,7 +595,7 @@ class RandomFlipHalfBodyTransform(object): def __call__(self, records): image = records['image'] - joints = records['joints'] + joints = records['gt_joints'] joints_vis = records['joints_vis'] c = records['center'] s = records['scale'] @@ -493,7 +618,7 @@ class RandomFlipHalfBodyTransform(object): joints, joints_vis, image.shape[1], self.flip_pairs) c[0] = image.shape[1] - c[0] - 1 records['image'] = image - records['joints'] = joints + records['gt_joints'] = joints records['joints_vis'] = joints_vis records['center'] = c records['scale'] = s @@ -553,7 +678,7 @@ class AugmentationbyInformantionDropping(object): def __call__(self, records): img = records['image'] - joints = records['joints'] + joints = records['gt_joints'] joints_vis = records['joints_vis'] if np.random.rand() < self.prob_cutout: img = self._cutout(img, joints, joints_vis) @@ -581,7 +706,7 @@ class TopDownAffine(object): def __call__(self, records): image = records['image'] - joints = records['joints'] + joints = records['gt_joints'] joints_vis = records['joints_vis'] rot = records['rotate'] if "rotate" in records else 0 if self.use_udp: @@ -606,7 +731,7 @@ class TopDownAffine(object): joints[i, 0:2] = affine_transform(joints[i, 0:2], trans) records['image'] = image - records['joints'] = joints + records['gt_joints'] = joints return records @@ -842,7 +967,7 @@ class ToHeatmapsTopDown(object): https://github.com/leoxiaobin/deep-high-resolution-net.pytorch Copyright (c) Microsoft, under the MIT License. 
""" - joints = records['joints'] + joints = records['gt_joints'] joints_vis = records['joints_vis'] num_joints = joints.shape[0] image_size = np.array( @@ -885,7 +1010,7 @@ class ToHeatmapsTopDown(object): 0]:g_y[1], g_x[0]:g_x[1]] records['target'] = target records['target_weight'] = target_weight - del records['joints'], records['joints_vis'] + del records['gt_joints'], records['joints_vis'] return records @@ -910,7 +1035,7 @@ class ToHeatmapsTopDown_DARK(object): self.sigma = sigma def __call__(self, records): - joints = records['joints'] + joints = records['gt_joints'] joints_vis = records['joints_vis'] num_joints = joints.shape[0] image_size = np.array( @@ -943,7 +1068,7 @@ class ToHeatmapsTopDown_DARK(object): (x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2)) records['target'] = target records['target_weight'] = target_weight - del records['joints'], records['joints_vis'] + del records['gt_joints'], records['joints_vis'] return records @@ -972,7 +1097,7 @@ class ToHeatmapsTopDown_UDP(object): self.sigma = sigma def __call__(self, records): - joints = records['joints'] + joints = records['gt_joints'] joints_vis = records['joints_vis'] num_joints = joints.shape[0] image_size = np.array( @@ -1017,6 +1142,472 @@ class ToHeatmapsTopDown_UDP(object): 0]:g_y[1], g_x[0]:g_x[1]] records['target'] = target records['target_weight'] = target_weight - del records['joints'], records['joints_vis'] + del records['gt_joints'], records['joints_vis'] return records + + +from typing import Optional, Tuple, Union, List +import numbers + + +def _scale_size( + size: Tuple[int, int], + scale: Union[float, int, tuple], ) -> Tuple[int, int]: + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) + + +def rescale_size(old_size: tuple, + scale: Union[float, int, tuple], + return_scale: bool=False) -> tuple: + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f'Invalid scale {scale}, must be positive.') + scale_factor = scale + elif isinstance(scale, list): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError( + f'Scale must be a number or tuple of int, but got {type(scale)}') + + new_size = _scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale(img: np.ndarray, + scale: Union[float, Tuple[int, int]], + return_scale: bool=False, + interpolation: str='bilinear', + backend: Optional[str]=None) -> Union[np.ndarray, Tuple[ + np.ndarray, float]]: + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. 
+ If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize( + img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +def imresize( + img: np.ndarray, + size: Tuple[int, int], + return_scale: bool=False, + interpolation: str='bilinear', + out: Optional[np.ndarray]=None, + backend: Optional[str]=None, + interp=cv2.INTER_LINEAR, ) -> Union[Tuple[np.ndarray, float, float], + np.ndarray]: + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if backend is None: + backend = imread_backend + if backend not in ['cv2', 'pillow']: + raise ValueError(f'backend: {backend} is not supported for resize.' + f"Supported backends are 'cv2', 'pillow'") + + if backend == 'pillow': + assert img.dtype == np.uint8, 'Pillow backend only support uint8 type' + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize(img, size, dst=out, interpolation=interp) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +class PETR_Resize: + """Resize images & bbox & mask. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. If the input dict contains the key + "scale", then the scale in the input dict is used, otherwise the specified + scale in the init method is used. If the input dict contains the key + "scale_factor" (if MultiScaleFlipAug does not give img_scale but + scale_factor), the actual scale will be computed by image shape and + scale_factor. + + `img_scale` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio \ + range and multiply it with the image scale. + - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ + sample a scale from the multiscale range. + - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ + sample a scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". 
+ ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + bbox_clip_border=True, + backend='cv2', + interpolation='bilinear', + override=False, + keypoint_clip_border=True): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert isinstance(self.img_scale, list) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.backend = backend + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + # TODO: refactor the override option in Resize + self.interpolation = interpolation + self.override = override + self.bbox_clip_border = bbox_clip_border + self.keypoint_clip_border = keypoint_clip_border + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ + where ``img_scale`` is the selected image scale and \ + ``scale_idx`` is the selected index in the given candidates. + """ + + assert isinstance(img_scales, list) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where \ + ``img_scale`` is sampled scale and None is just a placeholder \ + to be consistent with :func:`random_select`. 
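+
+        Example (illustrative; the sampled value is random)::
+
+            scale, _ = PETR_Resize.random_sample([(400, 1400), (1400, 1400)])
+            # scale is (long_edge, short_edge) with long_edge == 1400
+            # and 400 <= short_edge <= 1400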
+ """ + + assert isinstance(img_scales, list) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (list): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where \ + ``scale`` is sampled ratio multiplied with ``img_scale`` and \ + None is just a placeholder to be consistent with \ + :func:`random_select`. + """ + + assert isinstance(img_scale, list) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into \ + ``results``, which would be used by subsequent pipelines. 
+ """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio(self.img_scale[0], + self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + for key in ['image'] if 'image' in results else []: + if self.keep_ratio: + img, scale_factor = imrescale( + results[key], + results['scale'], + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results[key].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = imresize( + results[key], + results['scale'], + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + + scale_factor = np.array( + [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + results['im_shape'] = np.array(img.shape) + # in case that there is no padding + results['pad_shape'] = img.shape + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + # img_pad = self.impad(img, shape=results['scale']) + results[key] = img + + def _resize_bboxes(self, results): + """Resize bounding boxes with ``results['scale_factor']``.""" + for key in ['gt_bbox'] if 'gt_bbox' in results else []: + bboxes = results[key] * results['scale_factor'] + if self.bbox_clip_border: + img_shape = results['im_shape'] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + results[key] = bboxes + + def _resize_masks(self, results): + """Resize masks with ``results['scale']``""" + for key in ['mask'] if 'mask' in results else []: + if results[key] is None: + continue + if self.keep_ratio: + results[key] = results[key].rescale(results['scale']) + else: + results[key] = results[key].resize(results['im_shape'][:2]) + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in ['seg'] if 'seg' in results else []: + if self.keep_ratio: + gt_seg = imrescale( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = imresize( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results[key] = gt_seg + + def _resize_keypoints(self, results): + """Resize keypoints with ``results['scale_factor']``.""" + for key in ['gt_joints'] if 'gt_joints' in results else []: + keypoints = results[key].copy() + keypoints[..., 0] = keypoints[..., 0] * results['scale_factor'][0] + keypoints[..., 1] = keypoints[..., 1] * results['scale_factor'][1] + if self.keypoint_clip_border: + img_shape = results['im_shape'] + keypoints[..., 0] = np.clip(keypoints[..., 0], 0, img_shape[1]) + keypoints[..., 1] = np.clip(keypoints[..., 1], 0, img_shape[0]) + results[key] = keypoints + + def _resize_areas(self, results): + """Resize mask areas with ``results['scale_factor']``.""" + for key in ['gt_areas'] if 'gt_areas' in results else []: + areas = results[key].copy() + areas = areas * results['scale_factor'][0] * results[ + 'scale_factor'][1] + results[key] = 
areas + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'im_shape', 'pad_shape', 'scale_factor', \ + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + if 'scale_factor' in results: + img_shape = results['image'].shape[:2] + scale_factor = results['scale_factor'] + assert isinstance(scale_factor, float) + results['scale'] = tuple( + [int(x * scale_factor) for x in img_shape][::-1]) + else: + self._random_scale(results) + else: + if not self.override: + assert 'scale_factor' not in results, ( + 'scale and scale_factor cannot be both set.') + else: + results.pop('scale') + if 'scale_factor' in results: + results.pop('scale_factor') + self._random_scale(results) + + self._resize_img(results) + self._resize_bboxes(results) + self._resize_masks(results) + self._resize_seg(results) + self._resize_keypoints(results) + self._resize_areas(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'multiscale_mode={self.multiscale_mode}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + repr_str += f'keypoint_clip_border={self.keypoint_clip_border})' + return repr_str diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 2f57cdfe33f77569aeb3546fc6ec8227a5b1a65e..61a4aacba024e7b81cfd832ae219d6cfa05af09e 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -594,6 +594,108 @@ class RandomDistort(BaseOperator): return sample +@register_op +class PhotoMetricDistortion(BaseOperator): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + super(PhotoMetricDistortion, self).__init__() + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def apply(self, results, context=None): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. 
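+
+        Example (illustrative; ``img`` is assumed to be an HWC ``uint8`` BGR
+            array, e.g. read with ``cv2.imread``)::
+
+            distort = PhotoMetricDistortion(brightness_delta=32,
+                                            contrast_range=(0.5, 1.5),
+                                            saturation_range=(0.5, 1.5),
+                                            hue_delta=18)
+            results = distort.apply({'image': img})
+            # results['image'] is a float32 array with the same shape as img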
+ """ + + img = results['image'] + img = img.astype(np.float32) + # random brightness + if np.random.randint(2): + delta = np.random.uniform(-self.brightness_delta, + self.brightness_delta) + img += delta + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = np.random.randint(2) + if mode == 1: + if np.random.randint(2): + alpha = np.random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + + # convert color from BGR to HSV + img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + + # random saturation + if np.random.randint(2): + img[..., 1] *= np.random.uniform(self.saturation_lower, + self.saturation_upper) + + # random hue + if np.random.randint(2): + img[..., 0] += np.random.uniform(-self.hue_delta, self.hue_delta) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + + # convert color from HSV to BGR + img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) + + # random contrast + if mode == 0: + if np.random.randint(2): + alpha = np.random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + + # randomly swap channels + if np.random.randint(2): + img = img[..., np.random.permutation(3)] + + results['image'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(\nbrightness_delta={self.brightness_delta},\n' + repr_str += 'contrast_range=' + repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n' + repr_str += 'saturation_range=' + repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n' + repr_str += f'hue_delta={self.hue_delta})' + return repr_str + + @register_op class AutoAugment(BaseOperator): def __init__(self, autoaug_type="v1"): @@ -771,6 +873,19 @@ class Resize(BaseOperator): bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h) return bbox + def apply_area(self, area, scale): + im_scale_x, im_scale_y = scale + return area * im_scale_x * im_scale_y + + def apply_joints(self, joints, scale, size): + im_scale_x, im_scale_y = scale + resize_w, resize_h = size + joints[..., 0] *= im_scale_x + joints[..., 1] *= im_scale_y + joints[..., 0] = np.clip(joints[..., 0], 0, resize_w) + joints[..., 1] = np.clip(joints[..., 1], 0, resize_h) + return joints + def apply_segm(self, segms, im_size, scale): def _resize_poly(poly, im_scale_x, im_scale_y): resized_poly = np.array(poly).astype('float32') @@ -833,8 +948,8 @@ class Resize(BaseOperator): im_scale = min(target_size_min / im_size_min, target_size_max / im_size_max) - resize_h = im_scale * float(im_shape[0]) - resize_w = im_scale * float(im_shape[1]) + resize_h = int(im_scale * float(im_shape[0]) + 0.5) + resize_w = int(im_scale * float(im_shape[1]) + 0.5) im_scale_x = im_scale im_scale_y = im_scale @@ -878,6 +993,11 @@ class Resize(BaseOperator): [im_scale_x, im_scale_y], [resize_w, resize_h]) + # apply areas + if 'gt_areas' in sample: + sample['gt_areas'] = self.apply_area(sample['gt_areas'], + [im_scale_x, im_scale_y]) + # apply polygon if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2], @@ -911,6 +1031,11 @@ class Resize(BaseOperator): ] sample['gt_segm'] = np.asarray(masks).astype(np.uint8) + if 'gt_joints' in sample: + sample['gt_joints'] = self.apply_joints(sample['gt_joints'], + [im_scale_x, im_scale_y], + [resize_w, resize_h]) + return sample @@ -1362,7 +1487,8 @@ class RandomCrop(BaseOperator): num_attempts=50, allow_no_crop=True, cover_all_box=False, - is_mask_crop=False): + is_mask_crop=False, + ioumode="iou"): 
super(RandomCrop, self).__init__() self.aspect_ratio = aspect_ratio self.thresholds = thresholds @@ -1371,6 +1497,7 @@ class RandomCrop(BaseOperator): self.allow_no_crop = allow_no_crop self.cover_all_box = cover_all_box self.is_mask_crop = is_mask_crop + self.ioumode = ioumode def crop_segms(self, segms, valid_ids, crop, height, width): def _crop_poly(segm, crop): @@ -1516,9 +1643,14 @@ class RandomCrop(BaseOperator): crop_y = np.random.randint(0, h - crop_h) crop_x = np.random.randint(0, w - crop_w) crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] - iou = self._iou_matrix( - gt_bbox, np.array( - [crop_box], dtype=np.float32)) + if self.ioumode == "iof": + iou = self._gtcropiou_matrix( + gt_bbox, np.array( + [crop_box], dtype=np.float32)) + elif self.ioumode == "iou": + iou = self._iou_matrix( + gt_bbox, np.array( + [crop_box], dtype=np.float32)) if iou.max() < thresh: continue @@ -1582,6 +1714,10 @@ class RandomCrop(BaseOperator): sample['difficult'] = np.take( sample['difficult'], valid_ids, axis=0) + if 'gt_joints' in sample: + sample['gt_joints'] = self._crop_joints(sample['gt_joints'], + crop_box) + return sample return sample @@ -1596,6 +1732,16 @@ class RandomCrop(BaseOperator): area_o = (area_a[:, np.newaxis] + area_b - area_i) return area_i / (area_o + 1e-10) + def _gtcropiou_matrix(self, a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + area_o = (area_a[:, np.newaxis] + area_b - area_i) + return area_i / (area_a + 1e-10) + def _crop_box_with_center_constraint(self, box, crop): cropped_box = box.copy() @@ -1620,6 +1766,16 @@ class RandomCrop(BaseOperator): x1, y1, x2, y2 = crop return segm[:, y1:y2, x1:x2] + def _crop_joints(self, joints, crop): + x1, y1, x2, y2 = crop + joints[joints[..., 0] > x2, :] = 0 + joints[joints[..., 1] > y2, :] = 0 + joints[joints[..., 0] < x1, :] = 0 + joints[joints[..., 1] < y1, :] = 0 + joints[..., 0] -= x1 + joints[..., 1] -= y1 + return joints + @register_op class RandomScaledCrop(BaseOperator): @@ -1648,8 +1804,8 @@ class RandomScaledCrop(BaseOperator): random_dim = int(dim * random_scale) dim_max = max(h, w) scale = random_dim / dim_max - resize_w = w * scale - resize_h = h * scale + resize_w = int(w * scale + 0.5) + resize_h = int(h * scale + 0.5) offset_x = int(max(0, np.random.uniform(0., resize_w - dim))) offset_y = int(max(0, np.random.uniform(0., resize_h - dim))) @@ -2316,25 +2472,26 @@ class RandomResizeCrop(BaseOperator): is_mask_crop(bool): whether crop the segmentation. 
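+        ioumode (str): criterion used to score candidate crops against gt
+            boxes, either "iou" (default) or "iof" (intersection over the
+            gt-box area).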
""" - def __init__( - self, - resizes, - cropsizes, - prob=0.5, - mode='short', - keep_ratio=True, - interp=cv2.INTER_LINEAR, - num_attempts=3, - cover_all_box=False, - allow_no_crop=False, - thresholds=[0.3, 0.5, 0.7], - is_mask_crop=False, ): + def __init__(self, + resizes, + cropsizes, + prob=0.5, + mode='short', + keep_ratio=True, + interp=cv2.INTER_LINEAR, + num_attempts=3, + cover_all_box=False, + allow_no_crop=False, + thresholds=[0.3, 0.5, 0.7], + is_mask_crop=False, + ioumode="iou"): super(RandomResizeCrop, self).__init__() self.resizes = resizes self.cropsizes = cropsizes self.prob = prob self.mode = mode + self.ioumode = ioumode self.resizer = Resize(0, keep_ratio=keep_ratio, interp=interp) self.croper = RandomCrop( @@ -2389,9 +2546,14 @@ class RandomResizeCrop(BaseOperator): crop_x = random.randint(0, w - crop_w) crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] - iou = self._iou_matrix( - gt_bbox, np.array( - [crop_box], dtype=np.float32)) + if self.ioumode == "iof": + iou = self._gtcropiou_matrix( + gt_bbox, np.array( + [crop_box], dtype=np.float32)) + elif self.ioumode == "iou": + iou = self._iou_matrix( + gt_bbox, np.array( + [crop_box], dtype=np.float32)) if iou.max() < thresh: continue @@ -2447,6 +2609,14 @@ class RandomResizeCrop(BaseOperator): if 'is_crowd' in sample: sample['is_crowd'] = np.take( sample['is_crowd'], valid_ids, axis=0) + + if 'gt_areas' in sample: + sample['gt_areas'] = np.take( + sample['gt_areas'], valid_ids, axis=0) + + if 'gt_joints' in sample: + gt_joints = self._crop_joints(sample['gt_joints'], crop_box) + sample['gt_joints'] = gt_joints[valid_ids] return sample return sample @@ -2479,8 +2649,8 @@ class RandomResizeCrop(BaseOperator): im_scale = max(target_size_min / im_size_min, target_size_max / im_size_max) - resize_h = im_scale * float(im_shape[0]) - resize_w = im_scale * float(im_shape[1]) + resize_h = int(im_scale * float(im_shape[0]) + 0.5) + resize_w = int(im_scale * float(im_shape[1]) + 0.5) im_scale_x = im_scale im_scale_y = im_scale @@ -2540,6 +2710,11 @@ class RandomResizeCrop(BaseOperator): ] sample['gt_segm'] = np.asarray(masks).astype(np.uint8) + if 'gt_joints' in sample: + sample['gt_joints'] = self.apply_joints(sample['gt_joints'], + [im_scale_x, im_scale_y], + [resize_w, resize_h]) + return sample @@ -2612,10 +2787,10 @@ class RandomShortSideResize(BaseOperator): if w < h: ow = size - oh = int(size * h / w) + oh = int(round(size * h / w)) else: oh = size - ow = int(size * w / h) + ow = int(round(size * w / h)) return (ow, oh) @@ -2672,6 +2847,16 @@ class RandomShortSideResize(BaseOperator): for gt_segm in sample['gt_segm'] ] sample['gt_segm'] = np.asarray(masks).astype(np.uint8) + + if 'gt_joints' in sample: + sample['gt_joints'] = self.apply_joints( + sample['gt_joints'], [im_scale_x, im_scale_y], target_size) + + # apply areas + if 'gt_areas' in sample: + sample['gt_areas'] = self.apply_area(sample['gt_areas'], + [im_scale_x, im_scale_y]) + return sample def apply_bbox(self, bbox, scale, size): @@ -2683,6 +2868,23 @@ class RandomShortSideResize(BaseOperator): bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h) return bbox.astype('float32') + def apply_joints(self, joints, scale, size): + im_scale_x, im_scale_y = scale + resize_w, resize_h = size + joints[..., 0] *= im_scale_x + joints[..., 1] *= im_scale_y + # joints[joints[..., 0] >= resize_w, :] = 0 + # joints[joints[..., 1] >= resize_h, :] = 0 + # joints[joints[..., 0] < 0, :] = 0 + # joints[joints[..., 1] < 0, :] = 0 + joints[..., 0] = np.clip(joints[..., 0], 0, 
resize_w) + joints[..., 1] = np.clip(joints[..., 1], 0, resize_h) + return joints + + def apply_area(self, area, scale): + im_scale_x, im_scale_y = scale + return area * im_scale_x * im_scale_y + def apply_segm(self, segms, im_size, scale): def _resize_poly(poly, im_scale_x, im_scale_y): resized_poly = np.array(poly).astype('float32') @@ -2730,6 +2932,44 @@ class RandomShortSideResize(BaseOperator): return self.resize(sample, target_size, self.max_size, interp) +@register_op +class RandomShortSideRangeResize(RandomShortSideResize): + def __init__(self, scales, interp=cv2.INTER_LINEAR, random_interp=False): + """ + Resize the image randomly according to the short side. If max_size is not None, + the long side is scaled according to max_size. The whole process will be keep ratio. + Args: + short_side_sizes (list|tuple): Image target short side size. + interp (int): The interpolation method. + random_interp (bool): Whether random select interpolation method. + """ + super(RandomShortSideRangeResize, self).__init__(scales, None, interp, + random_interp) + + assert isinstance(scales, + Sequence), "short_side_sizes must be List or Tuple" + + self.scales = scales + + def random_sample(self, img_scales): + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale + + def apply(self, sample, context=None): + long_edge, short_edge = self.random_sample(self.short_side_sizes) + # print("target size:{}".format((long_edge, short_edge))) + interp = random.choice( + self.interps) if self.random_interp else self.interp + + return self.resize(sample, short_edge, long_edge, interp) + + @register_op class RandomSizeCrop(BaseOperator): """ @@ -2805,6 +3045,9 @@ class RandomSizeCrop(BaseOperator): sample['is_crowd'] = sample['is_crowd'][keep_index] if len( keep_index) > 0 else np.zeros( [0, 1], dtype=np.float32) + if 'gt_areas' in sample: + sample['gt_areas'] = np.take( + sample['gt_areas'], keep_index, axis=0) image_shape = sample['image'].shape[:2] sample['image'] = self.paddle_crop(sample['image'], *region) @@ -2826,6 +3069,12 @@ class RandomSizeCrop(BaseOperator): if keep_index is not None and len(keep_index) > 0: sample['gt_segm'] = sample['gt_segm'][keep_index] + if 'gt_joints' in sample: + gt_joints = self._crop_joints(sample['gt_joints'], region) + sample['gt_joints'] = gt_joints + if keep_index is not None: + sample['gt_joints'] = sample['gt_joints'][keep_index] + return sample def apply_bbox(self, bbox, region): @@ -2836,6 +3085,19 @@ class RandomSizeCrop(BaseOperator): crop_bbox = crop_bbox.clip(min=0) return crop_bbox.reshape([-1, 4]).astype('float32') + def _crop_joints(self, joints, region): + y1, x1, h, w = region + x2 = x1 + w + y2 = y1 + h + # x1, y1, x2, y2 = crop + joints[..., 0] -= x1 + joints[..., 1] -= y1 + joints[joints[..., 0] > w, :] = 0 + joints[joints[..., 1] > h, :] = 0 + joints[joints[..., 0] < 0, :] = 0 + joints[joints[..., 1] < 0, :] = 0 + return joints + def apply_segm(self, segms, region, image_shape): def _crop_poly(segm, crop): xmin, ymin, xmax, ymax = crop diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 5efdec0775ed46c4328ec90e7ccf483f79e64725..8899e5c0b4cb6957810dcbce20c35f55f0dcbdf2 100644 --- a/ppdet/modeling/architectures/__init__.py +++ 
b/ppdet/modeling/architectures/__init__.py @@ -72,3 +72,4 @@ from .yolof import * from .pose3d_metro import * from .centertrack import * from .queryinst import * +from .keypoint_petr import * diff --git a/ppdet/modeling/architectures/keypoint_hrnet.py b/ppdet/modeling/architectures/keypoint_hrnet.py index fa3541d7d783b70fab8eb28dbdd8914b7394f6b4..1d93e3af5f5d4e4b0be173dd64ea37f01f7b31be 100644 --- a/ppdet/modeling/architectures/keypoint_hrnet.py +++ b/ppdet/modeling/architectures/keypoint_hrnet.py @@ -394,6 +394,7 @@ class TinyPose3DHRNet(BaseArch): def __init__(self, width, num_joints, + fc_channel=768, backbone='HRNet', loss='KeyPointRegressionMSELoss', post_process=TinyPose3DPostProcess): @@ -411,21 +412,13 @@ class TinyPose3DHRNet(BaseArch): self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True) - self.final_conv_new = L.Conv2d( - width, num_joints * 32, 1, 1, 0, bias=True) - self.flatten = paddle.nn.Flatten(start_axis=2, stop_axis=3) - self.fc1 = paddle.nn.Linear(768, 256) + self.fc1 = paddle.nn.Linear(fc_channel, 256) self.act1 = paddle.nn.ReLU() self.fc2 = paddle.nn.Linear(256, 64) self.act2 = paddle.nn.ReLU() self.fc3 = paddle.nn.Linear(64, 3) - # for human3.6M - self.fc1_1 = paddle.nn.Linear(3136, 1024) - self.fc2_1 = paddle.nn.Linear(1024, 256) - self.fc3_1 = paddle.nn.Linear(256, 3) - @classmethod def from_config(cls, cfg, *args, **kwargs): # backbone diff --git a/ppdet/modeling/architectures/keypoint_petr.py b/ppdet/modeling/architectures/keypoint_petr.py new file mode 100644 index 0000000000000000000000000000000000000000..b587c1f0668c968371def398c08b5968839c5b6f --- /dev/null +++ b/ppdet/modeling/architectures/keypoint_petr.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +this code is base on https://github.com/hikvision-research/opera/blob/main/opera/models/detectors/petr.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from ppdet.core.workspace import register +from .meta_arch import BaseArch +from .. 
import layers as L + +__all__ = ['PETR'] + + +@register +class PETR(BaseArch): + __category__ = 'architecture' + __inject__ = ['backbone', 'neck', 'bbox_head'] + + def __init__(self, + backbone='ResNet', + neck='ChannelMapper', + bbox_head='PETRHead'): + """ + PETR, see https://openaccess.thecvf.com/content/CVPR2022/papers/Shi_End-to-End_Multi-Person_Pose_Estimation_With_Transformers_CVPR_2022_paper.pdf + + Args: + backbone (nn.Layer): backbone instance + neck (nn.Layer): neck between backbone and head + bbox_head (nn.Layer): model output and loss + """ + super(PETR, self).__init__() + self.backbone = backbone + if neck is not None: + self.with_neck = True + self.neck = neck + self.bbox_head = bbox_head + self.deploy = False + + def extract_feat(self, img): + """Directly extract features from the backbone+neck.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def get_inputs(self): + img_metas = [] + gt_bboxes = [] + gt_labels = [] + gt_keypoints = [] + gt_areas = [] + pad_gt_mask = self.inputs['pad_gt_mask'].astype("bool").squeeze(-1) + for idx, im_shape in enumerate(self.inputs['im_shape']): + img_meta = { + 'img_shape': im_shape.astype("int32").tolist() + [1, ], + 'batch_input_shape': self.inputs['image'].shape[-2:], + 'image_name': self.inputs['image_file'][idx] + } + img_metas.append(img_meta) + if (not pad_gt_mask[idx].any()): + gt_keypoints.append(self.inputs['gt_joints'][idx][:1]) + gt_labels.append(self.inputs['gt_class'][idx][:1]) + gt_bboxes.append(self.inputs['gt_bbox'][idx][:1]) + gt_areas.append(self.inputs['gt_areas'][idx][:1]) + continue + + gt_keypoints.append(self.inputs['gt_joints'][idx][pad_gt_mask[idx]]) + gt_labels.append(self.inputs['gt_class'][idx][pad_gt_mask[idx]]) + gt_bboxes.append(self.inputs['gt_bbox'][idx][pad_gt_mask[idx]]) + gt_areas.append(self.inputs['gt_areas'][idx][pad_gt_mask[idx]]) + + return img_metas, gt_bboxes, gt_labels, gt_keypoints, gt_areas + + def get_loss(self): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box. + gt_keypoints (list[Tensor]): Each item are the truth keypoints for + each image in [p^{1}_x, p^{1}_y, p^{1}_v, ..., p^{K}_x, + p^{K}_y, p^{K}_v] format. + gt_areas (list[Tensor]): mask areas corresponding to each box. + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + img_metas, gt_bboxes, gt_labels, gt_keypoints, gt_areas = self.get_inputs( + ) + gt_bboxes_ignore = getattr(self.inputs, 'gt_bboxes_ignore', None) + + x = self.extract_feat(self.inputs) + losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_keypoints, gt_areas, + gt_bboxes_ignore) + loss = 0 + for k, v in losses.items(): + loss += v + losses['loss'] = loss + + return losses + + def get_pred_numpy(self): + """Used for computing network flops. 
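+
+        Builds dummy ``img_metas`` from the padded input shape, so it can be
+        called without ground-truth annotations.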
+ """ + + img = self.inputs['image'] + batch_size, _, height, width = img.shape + dummy_img_metas = [ + dict( + batch_input_shape=(height, width), + img_shape=(height, width, 3), + scale_factor=(1., 1., 1., 1.)) for _ in range(batch_size) + ] + x = self.extract_feat(img) + outs = self.bbox_head(x, img_metas=dummy_img_metas) + bbox_list = self.bbox_head.get_bboxes( + *outs, dummy_img_metas, rescale=True) + return bbox_list + + def get_pred(self): + """ + """ + img = self.inputs['image'] + batch_size, _, height, width = img.shape + img_metas = [ + dict( + batch_input_shape=(height, width), + img_shape=(height, width, 3), + scale_factor=self.inputs['scale_factor'][i]) + for i in range(batch_size) + ] + kptpred = self.simple_test( + self.inputs, img_metas=img_metas, rescale=True) + keypoints = kptpred[0][1][0] + bboxs = kptpred[0][0][0] + keypoints[..., 2] = bboxs[:, None, 4] + res_lst = [[keypoints, bboxs[:, 4]]] + outputs = {'keypoint': res_lst} + return outputs + + def simple_test(self, inputs, img_metas, rescale=False): + """Test function without test time augmentation. + + Args: + inputs (list[paddle.Tensor]): List of multiple images. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[list[np.ndarray]]: BBox and keypoint results of each image + and classes. The outer list corresponds to each image. + The inner list corresponds to each class. + """ + batch_size = len(img_metas) + assert batch_size == 1, 'Currently only batch_size 1 for inference ' \ + f'mode is supported. Found batch_size {batch_size}.' + feat = self.extract_feat(inputs) + results_list = self.bbox_head.simple_test( + feat, img_metas, rescale=rescale) + + bbox_kpt_results = [ + self.bbox_kpt2result(det_bboxes, det_labels, det_kpts, + self.bbox_head.num_classes) + for det_bboxes, det_labels, det_kpts in results_list + ] + return bbox_kpt_results + + def bbox_kpt2result(self, bboxes, labels, kpts, num_classes): + """Convert detection results to a list of numpy arrays. + + Args: + bboxes (paddle.Tensor | np.ndarray): shape (n, 5). + labels (paddle.Tensor | np.ndarray): shape (n, ). + kpts (paddle.Tensor | np.ndarray): shape (n, K, 3). + num_classes (int): class number, including background class. + + Returns: + list(ndarray): bbox and keypoint results of each class. 
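+
+            Example (illustrative): with ``num_classes=1`` and ``n`` detected
+            persons, the result is ``([bboxes], [kpts])`` where ``bboxes`` has
+            shape ``(n, 5)`` and ``kpts`` has shape ``(n, K, 3)``.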
+ """ + if bboxes.shape[0] == 0: + return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)], \ + [np.zeros((0, kpts.size(1), 3), dtype=np.float32) + for i in range(num_classes)] + else: + if isinstance(bboxes, paddle.Tensor): + bboxes = bboxes.numpy() + labels = labels.numpy() + kpts = kpts.numpy() + return [bboxes[labels == i, :] for i in range(num_classes)], \ + [kpts[labels == i, :, :] for i in range(num_classes)] diff --git a/ppdet/modeling/assigners/__init__.py b/ppdet/modeling/assigners/__init__.py index da548298ae7ed9adcb3652d9502537de643d5535..f462a9fd35190148f0285686d691806f3af8f4e2 100644 --- a/ppdet/modeling/assigners/__init__.py +++ b/ppdet/modeling/assigners/__init__.py @@ -31,3 +31,5 @@ from .fcosr_assigner import * from .rotated_task_aligned_assigner import * from .task_aligned_assigner_cr import * from .uniform_assigner import * +from .hungarian_assigner import * +from .pose_utils import * diff --git a/ppdet/modeling/assigners/hungarian_assigner.py b/ppdet/modeling/assigners/hungarian_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..154c27ce978d5d959b7682e19a6c410dd8e9f0a4 --- /dev/null +++ b/ppdet/modeling/assigners/hungarian_assigner.py @@ -0,0 +1,316 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + +import paddle + +from ppdet.core.workspace import register + +__all__ = ['PoseHungarianAssigner', 'PseudoSampler'] + + +class AssignResult: + """Stores assignments between predicted and truth boxes. + + Attributes: + num_gts (int): the number of truth boxes considered when computing this + assignment + + gt_inds (LongTensor): for each predicted box indicates the 1-based + index of the assigned truth box. 0 means unassigned and -1 means + ignore. + + max_overlaps (FloatTensor): the iou between the predicted box and its + assigned truth box. + + labels (None | LongTensor): If specified, for each predicted box + indicates the category label of the assigned truth box. 
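+
+    Example (illustrative): with ``num_gts=2`` and four predictions,
+    ``gt_inds = [0, 2, 0, 1]`` matches prediction 1 to the second gt and
+    prediction 3 to the first gt (indices are 1-based), while predictions
+    0 and 2 stay unassigned (background).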
+ """ + + def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.max_overlaps = max_overlaps + self.labels = labels + # Interface for possible user-defined properties + self._extra_properties = {} + + @property + def num_preds(self): + """int: the number of predictions in this assignment""" + return len(self.gt_inds) + + def set_extra_property(self, key, value): + """Set user-defined new property.""" + assert key not in self.info + self._extra_properties[key] = value + + def get_extra_property(self, key): + """Get user-defined property.""" + return self._extra_properties.get(key, None) + + @property + def info(self): + """dict: a dictionary of info about the object""" + basic_info = { + 'num_gts': self.num_gts, + 'num_preds': self.num_preds, + 'gt_inds': self.gt_inds, + 'max_overlaps': self.max_overlaps, + 'labels': self.labels, + } + basic_info.update(self._extra_properties) + return basic_info + + +@register +class PoseHungarianAssigner: + """Computes one-to-one matching between predictions and ground truth. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression oks cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt. + - positive integer: positive sample, index (1-based) of assigned gt. + + Args: + cls_weight (int | float, optional): The scale factor for classification + cost. Default 1.0. + kpt_weight (int | float, optional): The scale factor for regression + L1 cost. Default 1.0. + oks_weight (int | float, optional): The scale factor for regression + oks cost. Default 1.0. + """ + __inject__ = ['cls_cost', 'kpt_cost', 'oks_cost'] + + def __init__(self, + cls_cost='ClassificationCost', + kpt_cost='KptL1Cost', + oks_cost='OksCost'): + self.cls_cost = cls_cost + self.kpt_cost = kpt_cost + self.oks_cost = oks_cost + + def assign(self, + cls_pred, + kpt_pred, + gt_labels, + gt_keypoints, + gt_areas, + img_meta, + eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + kpt_pred (Tensor): Predicted keypoints with normalized coordinates + (x_{i}, y_{i}), which are all in range [0, 1]. Shape + [num_query, K*2]. + gt_labels (Tensor): Label of `gt_keypoints`, shape (num_gt,). + gt_keypoints (Tensor): Ground truth keypoints with unnormalized + coordinates [p^{1}_x, p^{1}_y, p^{1}_v, ..., \ + p^{K}_x, p^{K}_y, p^{K}_v]. Shape [num_gt, K*3]. + gt_areas (Tensor): Ground truth mask areas, shape (num_gt,). 
+ img_meta (dict): Meta information for current image. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. + """ + num_gts, num_kpts = gt_keypoints.shape[0], kpt_pred.shape[0] + if not gt_keypoints.astype('bool').any(): + num_gts = 0 + + # 1. assign -1 by default + assigned_gt_inds = paddle.full((num_kpts, ), -1, dtype="int64") + assigned_labels = paddle.full((num_kpts, ), -1, dtype="int64") + if num_gts == 0 or num_kpts == 0: + # No ground truth or keypoints, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + img_h, img_w, _ = img_meta['img_shape'] + factor = paddle.to_tensor( + [img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype).reshape( + (1, -1)) + + # 2. compute the weighted costs + # classification cost + cls_cost = self.cls_cost(cls_pred, gt_labels) + + # keypoint regression L1 cost + gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1, + 3)) + valid_kpt_flag = gt_keypoints_reshape[..., -1] + kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, + 2)) + normalize_gt_keypoints = gt_keypoints_reshape[ + ..., :2] / factor[:, :2].unsqueeze(0) + kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints, + valid_kpt_flag) + # keypoint OKS cost + kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, + 2)) + kpt_pred_tmp = kpt_pred_tmp * factor[:, :2].unsqueeze(0) + oks_cost = self.oks_cost(kpt_pred_tmp, gt_keypoints_reshape[..., :2], + valid_kpt_flag, gt_areas) + # weighted sum of above three costs + cost = cls_cost + kpt_cost + oks_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = paddle.to_tensor(matched_row_inds) + matched_col_inds = paddle.to_tensor(matched_col_inds) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][ + ..., 0].astype("int64") + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + + +class SamplingResult: + """Bbox sampling result. 
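+
+    Records, for one image, which queries were sampled as positives and
+    negatives, plus the gt boxes and labels matched to the positive queries.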
+ """ + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, + gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + if pos_inds.size > 0: + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = paddle.zeros( + gt_bboxes.shape, dtype=gt_bboxes.dtype).reshape((-1, 4)) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.reshape((-1, 4)) + + self.pos_gt_bboxes = paddle.index_select( + gt_bboxes, + self.pos_assigned_gt_inds.astype('int64'), + axis=0) + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + """paddle.Tensor: concatenated positive and negative boxes""" + return paddle.concat([self.pos_bboxes, self.neg_bboxes]) + + def __nice__(self): + data = self.info.copy() + data['pos_bboxes'] = data.pop('pos_bboxes').shape + data['neg_bboxes'] = data.pop('neg_bboxes').shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_bboxes': self.pos_bboxes, + 'neg_bboxes': self.neg_bboxes, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + } + + +@register +class PseudoSampler: + """A pseudo sampler that does not do sampling actually.""" + + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs): + """Directly returns the positive and negative indices of samples. + + Args: + assign_result (:obj:`AssignResult`): Assigned results + bboxes (paddle.Tensor): Bounding boxes + gt_bboxes (paddle.Tensor): Ground truth boxes + + Returns: + :obj:`SamplingResult`: sampler results + """ + pos_inds = paddle.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1) + neg_inds = paddle.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1) + gt_flags = paddle.zeros([bboxes.shape[0]], dtype='int32') + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, + assign_result, gt_flags) + return sampling_result diff --git a/ppdet/modeling/assigners/pose_utils.py b/ppdet/modeling/assigners/pose_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..313215a4dd4fc3a61f08a378a7ef598c74265f8d --- /dev/null +++ b/ppdet/modeling/assigners/pose_utils.py @@ -0,0 +1,275 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle +import paddle.nn.functional as F + +from ppdet.core.workspace import register + +__all__ = ['KptL1Cost', 'OksCost', 'ClassificationCost'] + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +@register +class KptL1Cost(object): + """KptL1Cost. + + this function based on: https://github.com/hikvision-research/opera/blob/main/opera/core/bbox/match_costs/match_cost.py + + Args: + weight (int | float, optional): loss_weight. + """ + + def __init__(self, weight=1.0): + self.weight = weight + + def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag): + """ + Args: + kpt_pred (Tensor): Predicted keypoints with normalized coordinates + (x_{i}, y_{i}), which are all in range [0, 1]. Shape + [num_query, K, 2]. + gt_keypoints (Tensor): Ground truth keypoints with normalized + coordinates (x_{i}, y_{i}). Shape [num_gt, K, 2]. + valid_kpt_flag (Tensor): valid flag of ground truth keypoints. + Shape [num_gt, K]. + + Returns: + paddle.Tensor: kpt_cost value with weight. + """ + kpt_cost = [] + for i in range(len(gt_keypoints)): + if gt_keypoints[i].size == 0: + kpt_cost.append(kpt_pred.sum() * 0) + kpt_pred_tmp = kpt_pred.clone() + valid_flag = valid_kpt_flag[i] > 0 + valid_flag_expand = valid_flag.unsqueeze(0).unsqueeze(-1).expand_as( + kpt_pred_tmp) + if not valid_flag_expand.all(): + kpt_pred_tmp = masked_fill(kpt_pred_tmp, ~valid_flag_expand, 0) + cost = F.pairwise_distance( + kpt_pred_tmp.reshape((kpt_pred_tmp.shape[0], -1)), + gt_keypoints[i].reshape((-1, )).unsqueeze(0), + p=1, + keepdim=True) + avg_factor = paddle.clip( + valid_flag.astype('float32').sum() * 2, 1.0) + cost = cost / avg_factor + kpt_cost.append(cost) + kpt_cost = paddle.concat(kpt_cost, axis=1) + return kpt_cost * self.weight + + +@register +class OksCost(object): + """OksCost. + + this function based on: https://github.com/hikvision-research/opera/blob/main/opera/core/bbox/match_costs/match_cost.py + + Args: + num_keypoints (int): number of keypoints + weight (int | float, optional): loss_weight. + """ + + def __init__(self, num_keypoints=17, weight=1.0): + self.weight = weight + if num_keypoints == 17: + self.sigmas = np.array( + [ + .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, + 1.07, .87, .87, .89, .89 + ], + dtype=np.float32) / 10.0 + elif num_keypoints == 14: + self.sigmas = np.array( + [ + .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, + .89, .79, .79 + ], + dtype=np.float32) / 10.0 + else: + raise ValueError(f'Unsupported keypoints number {num_keypoints}') + + def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas): + """ + Args: + kpt_pred (Tensor): Predicted keypoints with unnormalized + coordinates (x_{i}, y_{i}). Shape [num_query, K, 2]. + gt_keypoints (Tensor): Ground truth keypoints with unnormalized + coordinates (x_{i}, y_{i}). Shape [num_gt, K, 2]. + valid_kpt_flag (Tensor): valid flag of ground truth keypoints. + Shape [num_gt, K]. + gt_areas (Tensor): Ground truth mask areas. Shape [num_gt,]. + + Returns: + paddle.Tensor: oks_cost value with weight. 
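+
+        Note:
+            Per gt instance the cost equals
+            ``-mean_k exp(-d_k^2 / (2 * area * (2 * sigma_k)^2))`` over the
+            visible keypoints ``k``, scaled by ``self.weight``; a gt instance
+            with no visible keypoint contributes a zero cost column.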
+ """ + sigmas = paddle.to_tensor(self.sigmas) + variances = (sigmas * 2)**2 + + oks_cost = [] + assert len(gt_keypoints) == len(gt_areas) + for i in range(len(gt_keypoints)): + if gt_keypoints[i].size == 0: + oks_cost.append(kpt_pred.sum() * 0) + squared_distance = \ + (kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0)) ** 2 + \ + (kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2 + vis_flag = (valid_kpt_flag[i] > 0).astype('int') + vis_ind = vis_flag.nonzero(as_tuple=False)[:, 0] + num_vis_kpt = vis_ind.shape[0] + # assert num_vis_kpt > 0 + if num_vis_kpt == 0: + oks_cost.append(paddle.zeros((squared_distance.shape[0], 1))) + continue + area = gt_areas[i] + + squared_distance0 = squared_distance / (area * variances * 2) + squared_distance0 = paddle.index_select( + squared_distance0, vis_ind, axis=1) + squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1, + keepdim=True) + oks = squared_distance1 / num_vis_kpt + # The 1 is a constant that doesn't change the matching, so omitted. + oks_cost.append(-oks) + oks_cost = paddle.concat(oks_cost, axis=1) + return oks_cost * self.weight + + +@register +class ClassificationCost: + """ClsSoftmaxCost. + + Args: + weight (int | float, optional): loss_weight + """ + + def __init__(self, weight=1.): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_query, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + paddle.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@register +class FocalLossCost: + """FocalLossCost. + + Args: + weight (int | float, optional): loss_weight + alpha (int | float, optional): focal_loss alpha + gamma (int | float, optional): focal_loss gamma + eps (float, optional): default 1e-12 + binary_input (bool, optional): Whether the input is binary, + default False. + """ + + def __init__(self, + weight=1., + alpha=0.25, + gamma=2, + eps=1e-12, + binary_input=False): + self.weight = weight + self.alpha = alpha + self.gamma = gamma + self.eps = eps + self.binary_input = binary_input + + def _focal_loss_cost(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_query, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + paddle.Tensor: cls_cost value with weight + """ + if gt_labels.size == 0: + return cls_pred.sum() * 0 + cls_pred = F.sigmoid(cls_pred) + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = paddle.index_select( + pos_cost, gt_labels, axis=1) - paddle.index_select( + neg_cost, gt_labels, axis=1) + return cls_cost * self.weight + + def _mask_focal_loss_cost(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classfication logits + in shape (num_query, d1, ..., dn), dtype=paddle.float32. + gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn), + dtype=paddle.long. Labels should be binary. + + Returns: + Tensor: Focal cost matrix with weight in shape\ + (num_query, num_gt). 
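+
+        Note:
+            The focal terms are accumulated over the flattened spatial
+            positions with ``paddle.einsum`` and divided by the number of
+            positions ``n``, giving an average cost per (query, gt) pair.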
+ """ + cls_pred = cls_pred.flatten(1) + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + cls_pred = F.sigmoid(cls_pred) + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = paddle.einsum('nc,mc->nm', pos_cost, gt_labels) + \ + paddle.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + return cls_cost / n * self.weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classfication logits. + gt_labels (Tensor)): Labels. + + Returns: + Tensor: Focal cost matrix with weight in shape\ + (num_query, num_gt). + """ + if self.binary_input: + return self._mask_focal_loss_cost(cls_pred, gt_labels) + else: + return self._focal_loss_cost(cls_pred, gt_labels) diff --git a/ppdet/modeling/backbones/resnet.py b/ppdet/modeling/backbones/resnet.py index 6f8eb0b89ccd3cab1d0a08b7fd2a41e6332c3055..3b9508c49f932ffa34f53a946224ed8d7a3ae564 100755 --- a/ppdet/modeling/backbones/resnet.py +++ b/ppdet/modeling/backbones/resnet.py @@ -285,36 +285,6 @@ class BottleNeck(nn.Layer): # ResNeXt width = int(ch_out * (base_width / 64.)) * groups - self.shortcut = shortcut - if not shortcut: - if variant == 'd' and stride == 2: - self.short = nn.Sequential() - self.short.add_sublayer( - 'pool', - nn.AvgPool2D( - kernel_size=2, stride=2, padding=0, ceil_mode=True)) - self.short.add_sublayer( - 'conv', - ConvNormLayer( - ch_in=ch_in, - ch_out=ch_out * self.expansion, - filter_size=1, - stride=1, - norm_type=norm_type, - norm_decay=norm_decay, - freeze_norm=freeze_norm, - lr=lr)) - else: - self.short = ConvNormLayer( - ch_in=ch_in, - ch_out=ch_out * self.expansion, - filter_size=1, - stride=stride, - norm_type=norm_type, - norm_decay=norm_decay, - freeze_norm=freeze_norm, - lr=lr) - self.branch2a = ConvNormLayer( ch_in=ch_in, ch_out=width, @@ -351,6 +321,36 @@ class BottleNeck(nn.Layer): freeze_norm=freeze_norm, lr=lr) + self.shortcut = shortcut + if not shortcut: + if variant == 'd' and stride == 2: + self.short = nn.Sequential() + self.short.add_sublayer( + 'pool', + nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True)) + self.short.add_sublayer( + 'conv', + ConvNormLayer( + ch_in=ch_in, + ch_out=ch_out * self.expansion, + filter_size=1, + stride=1, + norm_type=norm_type, + norm_decay=norm_decay, + freeze_norm=freeze_norm, + lr=lr)) + else: + self.short = ConvNormLayer( + ch_in=ch_in, + ch_out=ch_out * self.expansion, + filter_size=1, + stride=stride, + norm_type=norm_type, + norm_decay=norm_decay, + freeze_norm=freeze_norm, + lr=lr) + self.std_senet = std_senet if self.std_senet: self.se = SELayer(ch_out * self.expansion) diff --git a/ppdet/modeling/backbones/vision_transformer.py b/ppdet/modeling/backbones/vision_transformer.py index 825724fa4b58319550a4f3e54c9a0d7d73183d3c..a21eefc7aca0d2a5fe0bfa94eddf007612f5f464 100644 --- a/ppdet/modeling/backbones/vision_transformer.py +++ b/ppdet/modeling/backbones/vision_transformer.py @@ -284,9 +284,9 @@ class RelativePositionBias(nn.Layer): def forward(self): relative_position_bias = \ - self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([ self.window_size[0] * self.window_size[1] + 1, - self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH return 
relative_position_bias.transpose((2, 0, 1)) # nH, Wh*Ww, Wh*Ww diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 9cceb268a7d32b8746b9c3995323b01c60cae99a..07df124cd3aeeb2b77910cee115700aed1234632 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -67,3 +67,4 @@ from .yolof_head import * from .ppyoloe_contrast_head import * from .centertrack_head import * from .sparse_roi_head import * +from .petr_head import * diff --git a/ppdet/modeling/heads/petr_head.py b/ppdet/modeling/heads/petr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..90760c665157b658f776e2ff9f7fbef0b525a006 --- /dev/null +++ b/ppdet/modeling/heads/petr_head.py @@ -0,0 +1,1161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +this code is base on https://github.com/hikvision-research/opera/blob/main/opera/models/dense_heads/petr_head.py +""" +import copy +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register +import paddle.distributed as dist + +from ..transformers.petr_transformer import inverse_sigmoid, masked_fill +from ..initializer import constant_, normal_ + +__all__ = ["PETRHead"] + +from functools import partial + + +def bias_init_with_prob(prior_prob: float) -> float: + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + return bias_init + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + res = tuple(map(list, zip(*map_results))) + return res + + +def reduce_mean(tensor): + """"Obtain the mean of tensor on different GPUs.""" + if not (dist.get_world_size() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce( + tensor.divide( + paddle.to_tensor( + dist.get_world_size(), dtype='float32')), + op=dist.ReduceOp.SUM) + return tensor + + +def gaussian_radius(det_size, min_overlap=0.7): + """calculate gaussian radius according to object size. 
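+
+    Following the CornerNet heatmap-target heuristic, three quadratic
+    constraints (both corners shifted inward, both outward, one of each) are
+    solved and the smallest of the three radii is returned, so that a box
+    whose corners move by at most this radius still overlaps the ground-truth
+    box of size ``det_size`` with IoU >= ``min_overlap``.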
+ """ + height, width = det_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = paddle.sqrt(b1**2 - 4 * a1 * c1) + r1 = (b1 + sq1) / 2 + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = paddle.sqrt(b2**2 - 4 * a2 * c2) + r2 = (b2 + sq2) / 2 + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = paddle.sqrt(b3**2 - 4 * a3 * c3) + r3 = (b3 + sq3) / 2 + return min(r1, r2, r3) + + +def gaussian2D(shape, sigma=1): + m, n = [(ss - 1.) / 2. for ss in shape] + y = paddle.arange(-m, m + 1, dtype="float32")[:, None] + x = paddle.arange(-n, n + 1, dtype="float32")[None, :] + # y, x = np.ogrid[-m:m + 1, -n:n + 1] + + h = paddle.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(np.float32).eps * h.max()] = 0 + return h + + +def draw_umich_gaussian(heatmap, center, radius, k=1): + diameter = 2 * radius + 1 + gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) + gaussian = paddle.to_tensor(gaussian, dtype=heatmap.dtype) + + x, y = int(center[0]), int(center[1]) + radius = int(radius) + + height, width = heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[radius - top:radius + bottom, radius - left: + radius + right] + # assert masked_gaussian.equal(1).float().sum() == 1 + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + heatmap[y - top:y + bottom, x - left:x + right] = paddle.maximum( + masked_heatmap, masked_gaussian * k) + return heatmap + + +@register +class PETRHead(nn.Layer): + """Head of `End-to-End Multi-Person Pose Estimation with Transformers`. + + Args: + num_classes (int): Number of categories excluding the background. + in_channels (int): Number of channels in the input feature map. + num_query (int): Number of query in Transformer. + num_kpt_fcs (int, optional): Number of fully-connected layers used in + `FFN`, which is then used for the keypoint regression head. + Default 2. + transformer (obj:`mmcv.ConfigDict`|dict): ConfigDict is used for + building the Encoder and Decoder. Default: None. + sync_cls_avg_factor (bool): Whether to sync the avg_factor of + all ranks. Default to False. + positional_encoding (obj:`mmcv.ConfigDict`|dict): + Config for position encoding. + loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the + classification loss. Default `CrossEntropyLoss`. + loss_kpt (obj:`mmcv.ConfigDict`|dict): Config of the + regression loss. Default `L1Loss`. + loss_oks (obj:`mmcv.ConfigDict`|dict): Config of the + regression oks loss. Default `OKSLoss`. + loss_hm (obj:`mmcv.ConfigDict`|dict): Config of the + regression heatmap loss. Default `NegLoss`. + as_two_stage (bool) : Whether to generate the proposal from + the outputs of encoder. + with_kpt_refine (bool): Whether to refine the reference points + in the decoder. Defaults to True. + test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of + transformer head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
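+
+    Note:
+        In PaddleDetection the transformer, positional encoding, losses,
+        assigner and sampler listed above are built from the YAML config and
+        injected through ``__inject__`` rather than passed as mmcv ConfigDicts.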
+ """ + __inject__ = [ + "transformer", "positional_encoding", "assigner", "sampler", "loss_cls", + "loss_kpt", "loss_oks", "loss_hm", "loss_kpt_rpn", "loss_kpt_refine", + "loss_oks_refine" + ] + + def __init__(self, + num_classes, + in_channels, + num_query=100, + num_kpt_fcs=2, + num_keypoints=17, + transformer=None, + sync_cls_avg_factor=True, + positional_encoding='SinePositionalEncoding', + loss_cls='FocalLoss', + loss_kpt='L1Loss', + loss_oks='OKSLoss', + loss_hm='CenterFocalLoss', + with_kpt_refine=True, + assigner='PoseHungarianAssigner', + sampler='PseudoSampler', + loss_kpt_rpn='L1Loss', + loss_kpt_refine='L1Loss', + loss_oks_refine='opera.OKSLoss', + test_cfg=dict(max_per_img=100), + init_cfg=None, + **kwargs): + # NOTE here use `AnchorFreeHead` instead of `TransformerHead`, + # since it brings inconvenience when the initialization of + # `AnchorFreeHead` is called. + super().__init__() + self.bg_cls_weight = 0 + self.sync_cls_avg_factor = sync_cls_avg_factor + self.assigner = assigner + self.sampler = sampler + self.num_query = num_query + self.num_classes = num_classes + self.in_channels = in_channels + self.num_kpt_fcs = num_kpt_fcs + self.test_cfg = test_cfg + self.fp16_enabled = False + self.as_two_stage = transformer.as_two_stage + self.with_kpt_refine = with_kpt_refine + self.num_keypoints = num_keypoints + self.loss_cls = loss_cls + self.loss_kpt = loss_kpt + self.loss_kpt_rpn = loss_kpt_rpn + self.loss_kpt_refine = loss_kpt_refine + self.loss_oks = loss_oks + self.loss_oks_refine = loss_oks_refine + self.loss_hm = loss_hm + if self.loss_cls.use_sigmoid: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + self.positional_encoding = positional_encoding + self.transformer = transformer + self.embed_dims = self.transformer.embed_dims + # assert 'num_feats' in positional_encoding + num_feats = positional_encoding.num_pos_feats + assert num_feats * 2 == self.embed_dims, 'embed_dims should' \ + f' be exactly 2 times of num_feats. Found {self.embed_dims}' \ + f' and {num_feats}.' + self._init_layers() + self.init_weights() + + def _init_layers(self): + """Initialize classification branch and keypoint branch of head.""" + + fc_cls = nn.Linear(self.embed_dims, self.cls_out_channels) + + kpt_branch = [] + kpt_branch.append(nn.Linear(self.embed_dims, 512)) + kpt_branch.append(nn.ReLU()) + for _ in range(self.num_kpt_fcs): + kpt_branch.append(nn.Linear(512, 512)) + kpt_branch.append(nn.ReLU()) + kpt_branch.append(nn.Linear(512, 2 * self.num_keypoints)) + kpt_branch = nn.Sequential(*kpt_branch) + + def _get_clones(module, N): + return nn.LayerList([copy.deepcopy(module) for i in range(N)]) + + # last kpt_branch is used to generate proposal from + # encode feature map when as_two_stage is True. 
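+        # when with_kpt_refine is True every prediction level gets its own
+        # deep-copied cls/kpt head; otherwise a single head instance is
+        # shared across all levels.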
+ num_pred = (self.transformer.decoder.num_layers + 1) if \ + self.as_two_stage else self.transformer.decoder.num_layers + + if self.with_kpt_refine: + self.cls_branches = _get_clones(fc_cls, num_pred) + self.kpt_branches = _get_clones(kpt_branch, num_pred) + else: + self.cls_branches = nn.LayerList([fc_cls for _ in range(num_pred)]) + self.kpt_branches = nn.LayerList( + [kpt_branch for _ in range(num_pred)]) + + self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2) + + refine_kpt_branch = [] + for _ in range(self.num_kpt_fcs): + refine_kpt_branch.append( + nn.Linear(self.embed_dims, self.embed_dims)) + refine_kpt_branch.append(nn.ReLU()) + refine_kpt_branch.append(nn.Linear(self.embed_dims, 2)) + refine_kpt_branch = nn.Sequential(*refine_kpt_branch) + if self.with_kpt_refine: + num_pred = self.transformer.refine_decoder.num_layers + self.refine_kpt_branches = _get_clones(refine_kpt_branch, num_pred) + self.fc_hm = nn.Linear(self.embed_dims, self.num_keypoints) + + def init_weights(self): + """Initialize weights of the PETR head.""" + self.transformer.init_weights() + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + constant_(m.bias, bias_init) + for m in self.kpt_branches: + constant_(m[-1].bias, 0) + # initialization of keypoint refinement branch + if self.with_kpt_refine: + for m in self.refine_kpt_branches: + constant_(m[-1].bias, 0) + # initialize bias for heatmap prediction + bias_init = bias_init_with_prob(0.1) + normal_(self.fc_hm.weight, std=0.01) + constant_(self.fc_hm.bias, bias_init) + + def forward(self, mlvl_feats, img_metas): + """Forward function. + + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 4D-tensor with shape + (N, C, H, W). + img_metas (list[dict]): List of image information. + + Returns: + outputs_classes (Tensor): Outputs from the classification head, + shape [nb_dec, bs, num_query, cls_out_channels]. Note + cls_out_channels should include background. + outputs_kpts (Tensor): Sigmoid outputs from the regression + head with normalized coordinate format (cx, cy, w, h). + Shape [nb_dec, bs, num_query, K*2]. + enc_outputs_class (Tensor): The score of each point on encode + feature map, has shape (N, h*w, num_class). Only when + as_two_stage is Ture it would be returned, otherwise + `None` would be returned. + enc_outputs_kpt (Tensor): The proposal generate from the + encode feature map, has shape (N, h*w, K*2). Only when + as_two_stage is Ture it would be returned, otherwise + `None` would be returned. 
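+
+        Example (illustrative shapes): with batch size 2, ``num_query=300``
+        and ``K=17`` keypoints, ``outputs_kpts`` has shape
+        ``[num_decoder_layers, 2, 300, 34]`` and ``outputs_classes`` has shape
+        ``[num_decoder_layers, 2, 300, cls_out_channels]``.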
+ """ + + batch_size = mlvl_feats[0].shape[0] + input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + img_masks = paddle.zeros( + (batch_size, input_img_h, input_img_w), dtype=mlvl_feats[0].dtype) + for img_id in range(batch_size): + img_h, img_w, _ = img_metas[img_id]['img_shape'] + img_masks[img_id, :img_h, :img_w] = 1 + + mlvl_masks = [] + mlvl_positional_encodings = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate( + img_masks[None], size=feat.shape[-2:]).squeeze(0)) + mlvl_positional_encodings.append( + self.positional_encoding(mlvl_masks[-1]).transpose( + [0, 3, 1, 2])) + + query_embeds = self.query_embedding.weight + hs, init_reference, inter_references, \ + enc_outputs_class, enc_outputs_kpt, hm_proto, memory = \ + self.transformer( + mlvl_feats, + mlvl_masks, + query_embeds, + mlvl_positional_encodings, + kpt_branches=self.kpt_branches \ + if self.with_kpt_refine else None, # noqa:E501 + cls_branches=self.cls_branches \ + if self.as_two_stage else None # noqa:E501 + ) + + outputs_classes = [] + outputs_kpts = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp_kpt = self.kpt_branches[lvl](hs[lvl]) + assert reference.shape[-1] == self.num_keypoints * 2 + tmp_kpt += reference + outputs_kpt = F.sigmoid(tmp_kpt) + outputs_classes.append(outputs_class) + outputs_kpts.append(outputs_kpt) + + outputs_classes = paddle.stack(outputs_classes) + outputs_kpts = paddle.stack(outputs_kpts) + + if hm_proto is not None: + # get heatmap prediction (training phase) + hm_memory, hm_mask = hm_proto + hm_pred = self.fc_hm(hm_memory) + hm_proto = (hm_pred.transpose((0, 3, 1, 2)), hm_mask) + + if self.as_two_stage: + return outputs_classes, outputs_kpts, \ + enc_outputs_class, F.sigmoid(enc_outputs_kpt), \ + hm_proto, memory, mlvl_masks + else: + raise RuntimeError('only "as_two_stage=True" is supported.') + + def forward_refine(self, memory, mlvl_masks, refine_targets, losses, + img_metas): + """Forward function. + + Args: + mlvl_masks (tuple[Tensor]): The key_padding_mask from + different level used for encoder and decoder, + each is a 3D-tensor with shape (bs, H, W). + losses (dict[str, Tensor]): A dictionary of loss components. + img_metas (list[dict]): List of image information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
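+
+        Note:
+            Only queries with a positive keypoint weight (i.e. matched
+            foreground queries) are refined; at inference time the refined
+            keypoints are returned directly instead of the loss dict.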
+ """ + kpt_preds, kpt_targets, area_targets, kpt_weights = refine_targets + pos_inds = kpt_weights.sum(-1) > 0 + if not pos_inds.any(): + pos_kpt_preds = paddle.zeros_like(kpt_preds[:1]) + pos_img_inds = paddle.zeros([1], dtype="int64") + else: + pos_kpt_preds = kpt_preds[pos_inds] + pos_img_inds = (pos_inds.nonzero() / + self.num_query).squeeze(1).astype("int64") + hs, init_reference, inter_references = self.transformer.forward_refine( + mlvl_masks, + memory, + pos_kpt_preds.detach(), + pos_img_inds, + kpt_branches=self.refine_kpt_branches + if self.with_kpt_refine else None, # noqa:E501 + ) + + outputs_kpts = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + tmp_kpt = self.refine_kpt_branches[lvl](hs[lvl]) + assert reference.shape[-1] == 2 + tmp_kpt += reference + outputs_kpt = F.sigmoid(tmp_kpt) + outputs_kpts.append(outputs_kpt) + outputs_kpts = paddle.stack(outputs_kpts) + + if not self.training: + return outputs_kpts + + num_valid_kpt = paddle.clip( + reduce_mean(kpt_weights.sum()), min=1).item() + num_total_pos = paddle.to_tensor( + [outputs_kpts.shape[1]], dtype=kpt_weights.dtype) + num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item() + + if not pos_inds.any(): + for i, kpt_refine_preds in enumerate(outputs_kpts): + loss_kpt = loss_oks = kpt_refine_preds.sum() * 0 + losses[f'd{i}.loss_kpt_refine'] = loss_kpt + losses[f'd{i}.loss_oks_refine'] = loss_oks + continue + return losses + + batch_size = mlvl_masks[0].shape[0] + factors = [] + for img_id in range(batch_size): + img_h, img_w, _ = img_metas[img_id]['img_shape'] + factor = paddle.to_tensor( + [img_w, img_h, img_w, img_h], + dtype="float32").squeeze(-1).unsqueeze(0).tile( + (self.num_query, 1)) + factors.append(factor) + factors = paddle.concat(factors, 0) + factors = factors[pos_inds][:, :2].tile((1, kpt_preds.shape[-1] // 2)) + + pos_kpt_weights = kpt_weights[pos_inds] + pos_kpt_targets = kpt_targets[pos_inds] + pos_kpt_targets_scaled = pos_kpt_targets * factors + pos_areas = area_targets[pos_inds] + pos_valid = kpt_weights[pos_inds][:, 0::2] + for i, kpt_refine_preds in enumerate(outputs_kpts): + if not pos_inds.any(): + print("refine kpt and oks skip") + loss_kpt = loss_oks = kpt_refine_preds.sum() * 0 + losses[f'd{i}.loss_kpt_refine'] = loss_kpt + losses[f'd{i}.loss_oks_refine'] = loss_oks + continue + + # kpt L1 Loss + pos_refine_preds = kpt_refine_preds.reshape( + (kpt_refine_preds.shape[0], -1)) + loss_kpt = self.loss_kpt_refine( + pos_refine_preds, + pos_kpt_targets, + pos_kpt_weights, + avg_factor=num_valid_kpt) + losses[f'd{i}.loss_kpt_refine'] = loss_kpt + # kpt oks loss + pos_refine_preds_scaled = pos_refine_preds * factors + assert (pos_areas > 0).all() + loss_oks = self.loss_oks_refine( + pos_refine_preds_scaled, + pos_kpt_targets_scaled, + pos_valid, + pos_areas, + avg_factor=num_total_pos) + losses[f'd{i}.loss_oks_refine'] = loss_oks + return losses + + # over-write because img_metas are needed as inputs for bbox_head. + def forward_train(self, + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_keypoints=None, + gt_areas=None, + gt_bboxes_ignore=None, + proposal_cfg=None, + **kwargs): + """Forward function for training mode. + + Args: + x (list[Tensor]): Features from backbone. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes (list[Tensor]): Ground truth bboxes of the image, + shape (num_gts, 4). 
+ gt_labels (list[Tensor]): Ground truth labels of each box, + shape (num_gts,). + gt_keypoints (list[Tensor]): Ground truth keypoints of the image, + shape (num_gts, K*3). + gt_areas (list[Tensor]): Ground truth mask areas of each box, + shape (num_gts,). + gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + proposal_cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert proposal_cfg is None, '"proposal_cfg" must be None' + outs = self(x, img_metas) + memory, mlvl_masks = outs[-2:] + outs = outs[:-2] + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, gt_keypoints, gt_areas, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, gt_keypoints, gt_areas, + img_metas) + losses_and_targets = self.loss( + *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + # losses = losses_and_targets + losses, refine_targets = losses_and_targets + # get pose refinement loss + losses = self.forward_refine(memory, mlvl_masks, refine_targets, losses, + img_metas) + return losses + + def loss(self, + all_cls_scores, + all_kpt_preds, + enc_cls_scores, + enc_kpt_preds, + enc_hm_proto, + gt_bboxes_list, + gt_labels_list, + gt_keypoints_list, + gt_areas_list, + img_metas, + gt_bboxes_ignore=None): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification score of all + decoder layers, has shape + [nb_dec, bs, num_query, cls_out_channels]. + all_kpt_preds (Tensor): Sigmoid regression + outputs of all decode layers. Each is a 4D-tensor with + normalized coordinate format (x_{i}, y_{i}) and shape + [nb_dec, bs, num_query, K*2]. + enc_cls_scores (Tensor): Classification scores of + points on encode feature map, has shape + (N, h*w, num_classes). Only be passed when as_two_stage is + True, otherwise is None. + enc_kpt_preds (Tensor): Regression results of each points + on the encode feature map, has shape (N, h*w, K*2). Only be + passed when as_two_stage is True, otherwise is None. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_keypoints_list (list[Tensor]): Ground truth keypoints for each + image with shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v, + ..., p^{K}_x, p^{K}_y, p^{K}_v] format. + gt_areas_list (list[Tensor]): Ground truth mask areas for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore (list[Tensor], optional): Bounding boxes + which can be ignored for each image. Default None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert gt_bboxes_ignore is None, \ + f'{self.__class__.__name__} only supports ' \ + f'for gt_bboxes_ignore setting to None.' 
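+        # The total objective assembled below has three groups of terms:
+        #   1) per-decoder-layer classification (focal), keypoint L1 and OKS
+        #      losses computed by `loss_single` via `multi_apply`;
+        #   2) encoder ("rpn") classification + keypoint L1 losses on the
+        #      two-stage proposals (`loss_single_rpn`);
+        #   3) a heatmap focal loss on the encoder heatmap proto (`loss_heatmap`).
+        # The last decoder layer's predictions and targets are also returned so
+        # the caller can compute the refinement losses in `forward_refine`.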
+ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_keypoints_list = [ + gt_keypoints_list for _ in range(num_dec_layers) + ] + all_gt_areas_list = [gt_areas_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + + losses_cls, losses_kpt, losses_oks, kpt_preds_list, kpt_targets_list, \ + area_targets_list, kpt_weights_list = multi_apply( + self.loss_single, all_cls_scores, all_kpt_preds, + all_gt_labels_list, all_gt_keypoints_list, + all_gt_areas_list, img_metas_list) + + loss_dict = dict() + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + binary_labels_list = [ + paddle.zeros_like(gt_labels_list[i]) + for i in range(len(img_metas)) + ] + enc_loss_cls, enc_losses_kpt = \ + self.loss_single_rpn( + enc_cls_scores, enc_kpt_preds, binary_labels_list, + gt_keypoints_list, gt_areas_list, img_metas) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_kpt'] = enc_losses_kpt + + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_kpt'] = losses_kpt[-1] + loss_dict['loss_oks'] = losses_oks[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_kpt_i, loss_oks_i in zip( + losses_cls[:-1], losses_kpt[:-1], losses_oks[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_kpt'] = loss_kpt_i + loss_dict[f'd{num_dec_layer}.loss_oks'] = loss_oks_i + num_dec_layer += 1 + + # losses of heatmap generated from P3 feature map + hm_pred, hm_mask = enc_hm_proto + loss_hm = self.loss_heatmap(hm_pred, hm_mask, gt_keypoints_list, + gt_labels_list, gt_bboxes_list) + loss_dict['loss_hm'] = loss_hm + + return loss_dict, (kpt_preds_list[-1], kpt_targets_list[-1], + area_targets_list[-1], kpt_weights_list[-1]) + + def loss_heatmap(self, hm_pred, hm_mask, gt_keypoints, gt_labels, + gt_bboxes): + assert hm_pred.shape[-2:] == hm_mask.shape[-2:] + num_img, _, h, w = hm_pred.shape + # placeholder of heatmap target (Gaussian distribution) + hm_target = paddle.zeros(hm_pred.shape, hm_pred.dtype) + for i, (gt_label, gt_bbox, gt_keypoint + ) in enumerate(zip(gt_labels, gt_bboxes, gt_keypoints)): + if gt_label.shape[0] == 0: + continue + gt_keypoint = gt_keypoint.reshape((gt_keypoint.shape[0], -1, + 3)).clone() + gt_keypoint[..., :2] /= 8 + + assert gt_keypoint[..., 0].max() <= w + 0.5 # new coordinate system + assert gt_keypoint[..., 1].max() <= h + 0.5 # new coordinate system + gt_bbox /= 8 + gt_w = gt_bbox[:, 2] - gt_bbox[:, 0] + gt_h = gt_bbox[:, 3] - gt_bbox[:, 1] + for j in range(gt_label.shape[0]): + # get heatmap radius + kp_radius = paddle.clip( + paddle.floor( + gaussian_radius( + (gt_h[j], gt_w[j]), min_overlap=0.9)), + min=0, + max=3) + for k in range(self.num_keypoints): + if gt_keypoint[j, k, 2] > 0: + gt_kp = gt_keypoint[j, k, :2] + gt_kp_int = paddle.floor(gt_kp) + hm_target[i, k] = draw_umich_gaussian( + hm_target[i, k], gt_kp_int, kp_radius) + # compute heatmap loss + hm_pred = paddle.clip( + F.sigmoid(hm_pred), min=1e-4, max=1 - 1e-4) # refer to CenterNet + loss_hm = self.loss_hm( + hm_pred, + hm_target.detach(), + mask=~hm_mask.astype("bool").unsqueeze(1)) + return loss_hm + + def loss_single(self, cls_scores, kpt_preds, gt_labels_list, + gt_keypoints_list, gt_areas_list, img_metas): + """Loss function for outputs from a single decoder layer of a single + feature level. 
+ + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. + kpt_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (x_{i}, y_{i}) and + shape [bs, num_query, K*2]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_keypoints_list (list[Tensor]): Ground truth keypoints for each + image with shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v, + ..., p^{K}_x, p^{K}_y, p^{K}_v] format. + gt_areas_list (list[Tensor]): Ground truth mask areas for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components for outputs from + a single decoder layer. + """ + num_imgs = cls_scores.shape[0] + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + kpt_preds_list = [kpt_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list, + gt_labels_list, gt_keypoints_list, + gt_areas_list, img_metas) + (labels_list, label_weights_list, kpt_targets_list, kpt_weights_list, + area_targets_list, num_total_pos, num_total_neg) = cls_reg_targets + labels = paddle.concat(labels_list, 0) + label_weights = paddle.concat(label_weights_list, 0) + kpt_targets = paddle.concat(kpt_targets_list, 0) + kpt_weights = paddle.concat(kpt_weights_list, 0) + area_targets = paddle.concat(area_targets_list, 0) + + # classification loss + cls_scores = cls_scores.reshape((-1, self.cls_out_channels)) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + paddle.to_tensor( + [cls_avg_factor], dtype=cls_scores.dtype)) + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt keypoints accross all gpus, for + # normalization purposes + num_total_pos = paddle.to_tensor([num_total_pos], dtype=loss_cls.dtype) + num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale keypoints + factors = [] + for img_meta, kpt_pred in zip(img_metas, kpt_preds): + img_h, img_w, _ = img_meta['img_shape'] + factor = paddle.to_tensor( + [img_w, img_h, img_w, img_h], + dtype=kpt_pred.dtype).squeeze().unsqueeze(0).tile( + (kpt_pred.shape[0], 1)) + factors.append(factor) + factors = paddle.concat(factors, 0) + + # keypoint regression loss + kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1])) + num_valid_kpt = paddle.clip( + reduce_mean(kpt_weights.sum()), min=1).item() + # assert num_valid_kpt == (kpt_targets>0).sum().item() + loss_kpt = self.loss_kpt( + kpt_preds, + kpt_targets.detach(), + kpt_weights.detach(), + avg_factor=num_valid_kpt) + + # keypoint oks loss + pos_inds = kpt_weights.sum(-1) > 0 + if not pos_inds.any(): + loss_oks = kpt_preds.sum() * 0 + else: + factors = factors[pos_inds][:, :2].tile(( + (1, kpt_preds.shape[-1] // 2))) + pos_kpt_preds = kpt_preds[pos_inds] * factors + pos_kpt_targets = kpt_targets[pos_inds] * factors + pos_areas = area_targets[pos_inds] + pos_valid = kpt_weights[pos_inds][..., 0::2] + assert (pos_areas > 0).all() + loss_oks = self.loss_oks( + pos_kpt_preds, + pos_kpt_targets, + pos_valid, + pos_areas, + avg_factor=num_total_pos) + return loss_cls, loss_kpt, 
loss_oks, kpt_preds, kpt_targets, \ + area_targets, kpt_weights + + def get_targets(self, cls_scores_list, kpt_preds_list, gt_labels_list, + gt_keypoints_list, gt_areas_list, img_metas): + """Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + kpt_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (x_{i}, y_{i}) and shape [num_query, K*2]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_keypoints_list (list[Tensor]): Ground truth keypoints for each + image with shape (num_gts, K*3). + gt_areas_list (list[Tensor]): Ground truth mask areas for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all + images. + - kpt_targets_list (list[Tensor]): Keypoint targets for all + images. + - kpt_weights_list (list[Tensor]): Keypoint weights for all + images. + - area_targets_list (list[Tensor]): area targets for all + images. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + (labels_list, label_weights_list, kpt_targets_list, kpt_weights_list, + area_targets_list, pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, cls_scores_list, kpt_preds_list, + gt_labels_list, gt_keypoints_list, gt_areas_list, img_metas) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, kpt_targets_list, + kpt_weights_list, area_targets_list, num_total_pos, + num_total_neg) + + def _get_target_single(self, cls_score, kpt_pred, gt_labels, gt_keypoints, + gt_areas, img_meta): + """Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_query, cls_out_channels]. + kpt_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (x_{i}, y_{i}) and + shape [num_query, K*2]. + gt_labels (Tensor): Ground truth class indices for one image + with shape (num_gts, ). + gt_keypoints (Tensor): Ground truth keypoints for one image with + shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v, ..., \ + p^{K}_x, p^{K}_y, p^{K}_v] format. + gt_areas (Tensor): Ground truth mask areas for one image + with shape (num_gts, ). + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor): Label weights of each image. + - kpt_targets (Tensor): Keypoint targets of each image. + - kpt_weights (Tensor): Keypoint weights of each image. + - area_targets (Tensor): Area targets of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. 
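+
+        Note:
+            Keypoint targets are normalized to the input image size. For
+            example, with img_shape (512, 512, 3) a visible joint at pixel
+            (x=256, y=128) becomes the target (0.5, 0.25) with weight 1.0;
+            joints with visibility 0 keep target 0 and weight 0.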
+ """ + num_bboxes = kpt_pred.shape[0] + # assigner and sampler + assign_result = self.assigner.assign(cls_score, kpt_pred, gt_labels, + gt_keypoints, gt_areas, img_meta) + sampling_result = self.sampler.sample(assign_result, kpt_pred, + gt_keypoints) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label targets + labels = paddle.full((num_bboxes, ), self.num_classes, dtype="int64") + label_weights = paddle.ones((num_bboxes, ), dtype=gt_labels.dtype) + kpt_targets = paddle.zeros_like(kpt_pred) + kpt_weights = paddle.zeros_like(kpt_pred) + area_targets = paddle.zeros((kpt_pred.shape[0], ), dtype=kpt_pred.dtype) + + if pos_inds.size == 0: + return (labels, label_weights, kpt_targets, kpt_weights, + area_targets, pos_inds, neg_inds) + + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds][ + ..., 0].astype("int64") + + img_h, img_w, _ = img_meta['img_shape'] + # keypoint targets + pos_gt_kpts = gt_keypoints[sampling_result.pos_assigned_gt_inds] + pos_gt_kpts = pos_gt_kpts.reshape( + (len(sampling_result.pos_assigned_gt_inds), -1, 3)) + valid_idx = pos_gt_kpts[:, :, 2] > 0 + pos_kpt_weights = kpt_weights[pos_inds].reshape( + (pos_gt_kpts.shape[0], kpt_weights.shape[-1] // 2, 2)) + # pos_kpt_weights[valid_idx][...] = 1.0 + pos_kpt_weights = masked_fill(pos_kpt_weights, + valid_idx.unsqueeze(-1), 1.0) + kpt_weights[pos_inds] = pos_kpt_weights.reshape( + (pos_kpt_weights.shape[0], kpt_pred.shape[-1])) + + factor = paddle.to_tensor( + [img_w, img_h], dtype=kpt_pred.dtype).squeeze().unsqueeze(0) + pos_gt_kpts_normalized = pos_gt_kpts[..., :2] + pos_gt_kpts_normalized[..., 0] = pos_gt_kpts_normalized[..., 0] / \ + factor[:, 0:1] + pos_gt_kpts_normalized[..., 1] = pos_gt_kpts_normalized[..., 1] / \ + factor[:, 1:2] + kpt_targets[pos_inds] = pos_gt_kpts_normalized.reshape( + (pos_gt_kpts.shape[0], kpt_pred.shape[-1])) + + pos_gt_areas = gt_areas[sampling_result.pos_assigned_gt_inds][..., 0] + area_targets[pos_inds] = pos_gt_areas + + return (labels, label_weights, kpt_targets, kpt_weights, area_targets, + pos_inds, neg_inds) + + def loss_single_rpn(self, cls_scores, kpt_preds, gt_labels_list, + gt_keypoints_list, gt_areas_list, img_metas): + """Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. + kpt_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (x_{i}, y_{i}) and + shape [bs, num_query, K*2]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_keypoints_list (list[Tensor]): Ground truth keypoints for each + image with shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v, + ..., p^{K}_x, p^{K}_y, p^{K}_v] format. + gt_areas_list (list[Tensor]): Ground truth mask areas for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components for outputs from + a single decoder layer. 
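+
+        Note:
+            This branch supervises the two-stage encoder proposals only: all
+            targets are assigned the single "person" class and only the focal
+            classification loss and the keypoint L1 loss are returned; OKS and
+            heatmap supervision are applied to the decoder outputs elsewhere.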
+ """ + num_imgs = cls_scores.shape[0] + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + kpt_preds_list = [kpt_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list, + gt_labels_list, gt_keypoints_list, + gt_areas_list, img_metas) + (labels_list, label_weights_list, kpt_targets_list, kpt_weights_list, + area_targets_list, num_total_pos, num_total_neg) = cls_reg_targets + labels = paddle.concat(labels_list, 0) + label_weights = paddle.concat(label_weights_list, 0) + kpt_targets = paddle.concat(kpt_targets_list, 0) + kpt_weights = paddle.concat(kpt_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape((-1, self.cls_out_channels)) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + paddle.to_tensor( + [cls_avg_factor], dtype=cls_scores.dtype)) + cls_avg_factor = max(cls_avg_factor, 1) + + cls_avg_factor = max(cls_avg_factor, 1) + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt keypoints accross all gpus, for + # normalization purposes + # num_total_pos = loss_cls.to_tensor([num_total_pos]) + # num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item() + + # keypoint regression loss + kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1])) + num_valid_kpt = paddle.clip( + reduce_mean(kpt_weights.sum()), min=1).item() + # assert num_valid_kpt == (kpt_targets>0).sum().item() + loss_kpt = self.loss_kpt_rpn( + kpt_preds, kpt_targets, kpt_weights, avg_factor=num_valid_kpt) + + return loss_cls, loss_kpt + + def get_bboxes(self, + all_cls_scores, + all_kpt_preds, + enc_cls_scores, + enc_kpt_preds, + hm_proto, + memory, + mlvl_masks, + img_metas, + rescale=False): + """Transform network outputs for a batch into bbox predictions. + + Args: + all_cls_scores (Tensor): Classification score of all + decoder layers, has shape + [nb_dec, bs, num_query, cls_out_channels]. + all_kpt_preds (Tensor): Sigmoid regression + outputs of all decode layers. Each is a 4D-tensor with + normalized coordinate format (x_{i}, y_{i}) and shape + [nb_dec, bs, num_query, K*2]. + enc_cls_scores (Tensor): Classification scores of points on + encode feature map, has shape (N, h*w, num_classes). + Only be passed when as_two_stage is True, otherwise is None. + enc_kpt_preds (Tensor): Regression results of each points + on the encode feature map, has shape (N, h*w, K*2). Only be + passed when as_two_stage is True, otherwise is None. + img_metas (list[dict]): Meta information of each image. + rescale (bool, optional): If True, return boxes in original + image space. Defalut False. + + Returns: + list[list[Tensor, Tensor]]: Each item in result_list is 3-tuple. + The first item is an (n, 5) tensor, where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class label of + the corresponding box. The third item is an (n, K, 3) tensor + with [p^{1}_x, p^{1}_y, p^{1}_v, ..., p^{K}_x, p^{K}_y, + p^{K}_v] format. 
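+
+        Note:
+            Boxes are not regressed directly: each detection's box is the
+            axis-aligned rectangle circumscribing its refined keypoints, with
+            the classification score appended as the 5th column, and the
+            per-keypoint "score" in the last output is a constant 1.0
+            placeholder (see `_get_bboxes_single`).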
+ """ + cls_scores = all_cls_scores[-1] + kpt_preds = all_kpt_preds[-1] + + result_list = [] + for img_id in range(len(img_metas)): + cls_score = cls_scores[img_id] + kpt_pred = kpt_preds[img_id] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + # TODO: only support single image test + # memory_i = memory[:, img_id, :] + # mlvl_mask = mlvl_masks[img_id] + proposals = self._get_bboxes_single(cls_score, kpt_pred, img_shape, + scale_factor, memory, + mlvl_masks, rescale) + result_list.append(proposals) + return result_list + + def _get_bboxes_single(self, + cls_score, + kpt_pred, + img_shape, + scale_factor, + memory, + mlvl_masks, + rescale=False): + """Transform outputs from the last decoder layer into bbox predictions + for each image. + + Args: + cls_score (Tensor): Box score logits from the last decoder layer + for each image. Shape [num_query, cls_out_channels]. + kpt_pred (Tensor): Sigmoid outputs from the last decoder layer + for each image, with coordinate format (x_{i}, y_{i}) and + shape [num_query, K*2]. + img_shape (tuple[int]): Shape of input image, (height, width, 3). + scale_factor (ndarray, optional): Scale factor of the image arange + as (w_scale, h_scale, w_scale, h_scale). + rescale (bool, optional): If True, return boxes in original image + space. Default False. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. + + - det_bboxes: Predicted bboxes with shape [num_query, 5], + where the first 4 columns are bounding box positions + (tl_x, tl_y, br_x, br_y) and the 5-th column are scores + between 0 and 1. + - det_labels: Predicted labels of the corresponding box with + shape [num_query]. + - det_kpts: Predicted keypoints with shape [num_query, K, 3]. + """ + assert len(cls_score) == len(kpt_pred) + max_per_img = self.test_cfg.get('max_per_img', self.num_query) + # exclude background + if self.loss_cls.use_sigmoid: + cls_score = F.sigmoid(cls_score) + scores, indexs = cls_score.reshape([-1]).topk(max_per_img) + det_labels = indexs % self.num_classes + bbox_index = indexs // self.num_classes + kpt_pred = kpt_pred[bbox_index] + else: + scores, det_labels = F.softmax(cls_score, axis=-1)[..., :-1].max(-1) + scores, bbox_index = scores.topk(max_per_img) + kpt_pred = kpt_pred[bbox_index] + det_labels = det_labels[bbox_index] + + # ----- results after pose decoder ----- + # det_kpts = kpt_pred.reshape((kpt_pred.shape[0], -1, 2)) + + # ----- results after joint decoder (default) ----- + # import time + # start = time.time() + refine_targets = (kpt_pred, None, None, paddle.ones_like(kpt_pred)) + refine_outputs = self.forward_refine(memory, mlvl_masks, refine_targets, + None, None) + # end = time.time() + # print(f'refine time: {end - start:.6f}') + det_kpts = refine_outputs[-1] + + det_kpts[..., 0] = det_kpts[..., 0] * img_shape[1] + det_kpts[..., 1] = det_kpts[..., 1] * img_shape[0] + det_kpts[..., 0].clip_(min=0, max=img_shape[1]) + det_kpts[..., 1].clip_(min=0, max=img_shape[0]) + if rescale: + det_kpts /= paddle.to_tensor( + scale_factor[:2], + dtype=det_kpts.dtype).unsqueeze(0).unsqueeze(0) + + # use circumscribed rectangle box of keypoints as det bboxes + x1 = det_kpts[..., 0].min(axis=1, keepdim=True) + y1 = det_kpts[..., 1].min(axis=1, keepdim=True) + x2 = det_kpts[..., 0].max(axis=1, keepdim=True) + y2 = det_kpts[..., 1].max(axis=1, keepdim=True) + det_bboxes = paddle.concat([x1, y1, x2, y2], axis=1) + det_bboxes = paddle.concat((det_bboxes, scores.unsqueeze(1)), -1) + + det_kpts = paddle.concat( + (det_kpts, 
paddle.ones( + det_kpts[..., :1].shape, dtype=det_kpts.dtype)), + axis=2) + + return det_bboxes, det_labels, det_kpts + + def simple_test(self, feats, img_metas, rescale=False): + """Test det bboxes without test-time augmentation. + + Args: + feats (tuple[paddle.Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[tuple[Tensor, Tensor, Tensor]]: Each item in result_list is + 3-tuple. The first item is ``bboxes`` with shape (n, 5), + where 5 represent (tl_x, tl_y, br_x, br_y, score). + The shape of the second tensor in the tuple is ``labels`` + with shape (n,). The third item is ``kpts`` with shape + (n, K, 3), in [p^{1}_x, p^{1}_y, p^{1}_v, p^{K}_x, p^{K}_y, + p^{K}_v] format. + """ + # forward of this head requires img_metas + outs = self.forward(feats, img_metas) + results_list = self.get_bboxes(*outs, img_metas, rescale=rescale) + return results_list + + def get_loss(self, boxes, scores, gt_bbox, gt_class, prior_boxes): + return self.loss(boxes, scores, gt_bbox, gt_class, prior_boxes) diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 388be3cc9b93f46cdfb00bf942bf43e8fcfa79fd..16368e81e62d01cae7e628a30bd7c98cb9dcb234 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -1135,7 +1135,7 @@ def _convert_attention_mask(attn_mask, dtype): """ return nn.layer.transformer._convert_attention_mask(attn_mask, dtype) - +@register class MultiHeadAttention(nn.Layer): """ Attention mapps queries and a set of key-value pairs to outputs, and diff --git a/ppdet/modeling/losses/focal_loss.py b/ppdet/modeling/losses/focal_loss.py index 083e1dd3dbd17f849238fcf060416bfdc6765216..b9a64e1bc22d7e69256b311639ceb450c1381798 100644 --- a/ppdet/modeling/losses/focal_loss.py +++ b/ppdet/modeling/losses/focal_loss.py @@ -21,7 +21,7 @@ import paddle.nn.functional as F import paddle.nn as nn from ppdet.core.workspace import register -__all__ = ['FocalLoss'] +__all__ = ['FocalLoss', 'Weighted_FocalLoss'] @register class FocalLoss(nn.Layer): @@ -59,3 +59,80 @@ class FocalLoss(nn.Layer): pred, target, alpha=self.alpha, gamma=self.gamma, reduction=reduction) return loss * self.loss_weight + + +@register +class Weighted_FocalLoss(FocalLoss): + """A wrapper around paddle.nn.functional.sigmoid_focal_loss. + Args: + use_sigmoid (bool): currently only support use_sigmoid=True + alpha (float): parameter alpha in Focal Loss + gamma (float): parameter gamma in Focal Loss + loss_weight (float): final loss will be multiplied by this + """ + def __init__(self, + use_sigmoid=True, + alpha=0.25, + gamma=2.0, + loss_weight=1.0, + reduction="mean"): + super(FocalLoss, self).__init__() + assert use_sigmoid == True, \ + 'Focal Loss only supports sigmoid at the moment' + self.use_sigmoid = use_sigmoid + self.alpha = alpha + self.gamma = gamma + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None): + """forward function. 
+ Args: + pred (Tensor): logits of class prediction, of shape (N, num_classes) + target (Tensor): target class label, of shape (N, ) + reduction (str): the way to reduce loss, one of (none, sum, mean) + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + num_classes = pred.shape[1] + target = F.one_hot(target, num_classes + 1).astype(pred.dtype) + target = target[:, :-1].detach() + loss = F.sigmoid_focal_loss( + pred, target, alpha=self.alpha, gamma=self.gamma, + reduction='none') + + if weight is not None: + if weight.shape != loss.shape: + if weight.shape[0] == loss.shape[0]: + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.reshape((-1, 1)) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.reshape((loss.shape[0], -1)) + assert weight.ndim == loss.ndim + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = 1e-10 + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + + return loss * self.loss_weight diff --git a/ppdet/modeling/losses/keypoint_loss.py b/ppdet/modeling/losses/keypoint_loss.py index 9c3c113db36338814643bc3e3d907e62b98e992f..37a24102a85eec227ef3acd05f0814274f275e54 100644 --- a/ppdet/modeling/losses/keypoint_loss.py +++ b/ppdet/modeling/losses/keypoint_loss.py @@ -18,12 +18,13 @@ from __future__ import print_function from itertools import cycle, islice from collections import abc +import numpy as np import paddle import paddle.nn as nn from ppdet.core.workspace import register, serializable -__all__ = ['HrHRNetLoss', 'KeyPointMSELoss'] +__all__ = ['HrHRNetLoss', 'KeyPointMSELoss', 'OKSLoss', 'CenterFocalLoss', 'L1Loss'] @register @@ -226,3 +227,406 @@ def recursive_sum(inputs): if isinstance(inputs, abc.Sequence): return sum([recursive_sum(x) for x in inputs]) return inputs + + +def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas): + if not kpt_gts.astype('bool').any(): + return kpt_preds.sum()*0 + + sigmas = paddle.to_tensor(sigmas, dtype=kpt_preds.dtype) + variances = (sigmas * 2)**2 + + assert kpt_preds.shape[0] == kpt_gts.shape[0] + kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1] // 2, 2)) + kpt_gts = kpt_gts.reshape((-1, kpt_gts.shape[-1] // 2, 2)) + + squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \ + (kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2 + assert (kpt_valids.sum(-1) > 0).all() + squared_distance0 = squared_distance / ( + kpt_areas[:, None] * variances[None, :] * 2) + squared_distance1 = paddle.exp(-squared_distance0) + squared_distance1 = squared_distance1 * kpt_valids + oks = squared_distance1.sum(axis=1) / kpt_valids.sum(axis=1) + + return oks + + +def oks_loss(pred, + target, + weight, + valid=None, + area=None, + 
linear=False, + sigmas=None, + eps=1e-6, + avg_factor=None, + reduction=None): + """Oks loss. + + Computing the oks loss between a set of predicted poses and target poses. + The loss is calculated as negative log of oks. + + Args: + pred (Tensor): Predicted poses of format (x1, y1, x2, y2, ...), + shape (n, K*2). + target (Tensor): Corresponding gt poses, shape (n, K*2). + linear (bool, optional): If True, use linear scale of loss instead of + log scale. Default: False. + eps (float): Eps to avoid log(0). + + Returns: + Tensor: Loss tensor. + """ + oks = oks_overlaps(pred, target, valid, area, sigmas).clip(min=eps) + if linear: + loss = 1 - oks + else: + loss = -oks.log() + + if weight is not None: + if weight.shape != loss.shape: + if weight.shape[0] == loss.shape[0]: + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.reshape((-1, 1)) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.reshape((loss.shape[0], -1)) + assert weight.ndim == loss.ndim + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = 1e-10 + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + + + return loss + +@register +@serializable +class OKSLoss(nn.Layer): + """OKSLoss. + + Computing the oks loss between a set of predicted poses and target poses. + + Args: + linear (bool): If True, use linear scale of loss instead of log scale. + Default: False. + eps (float): Eps to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. + """ + + def __init__(self, + linear=False, + num_keypoints=17, + eps=1e-6, + reduction='mean', + loss_weight=1.0): + super(OKSLoss, self).__init__() + self.linear = linear + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + if num_keypoints == 17: + self.sigmas = np.array([ + .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, + 1.07, .87, .87, .89, .89 + ], dtype=np.float32) / 10.0 + elif num_keypoints == 14: + self.sigmas = np.array([ + .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, + .79, .79 + ]) / 10.0 + else: + raise ValueError(f'Unsupported keypoints number {num_keypoints}') + + def forward(self, + pred, + target, + valid, + area, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + valid (Tensor): The visible flag of the target pose. + area (Tensor): The area of the target pose. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
+ reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if (weight is not None) and (not paddle.any(weight > 0)) and ( + reduction != 'none'): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # iou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * oks_loss( + pred, + target, + weight, + valid=valid, + area=area, + linear=self.linear, + sigmas=self.sigmas, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +def center_focal_loss(pred, gt, weight=None, mask=None, avg_factor=None, reduction=None): + """Modified focal loss. Exactly the same as CornerNet. + Runs faster and costs a little bit more memory. + + Args: + pred (Tensor): The prediction with shape [bs, c, h, w]. + gt (Tensor): The learning target of the prediction in gaussian + distribution, with shape [bs, c, h, w]. + mask (Tensor): The valid mask. Defaults to None. + """ + if not gt.astype('bool').any(): + return pred.sum()*0 + pos_inds = gt.equal(1).astype('float32') + if mask is None: + neg_inds = gt.less_than(paddle.to_tensor([1], dtype='float32')).astype('float32') + else: + neg_inds = gt.less_than(paddle.to_tensor([1], dtype='float32')).astype('float32') * mask.equal(0).astype('float32') + + neg_weights = paddle.pow(1 - gt, 4) + + loss = 0 + + pos_loss = paddle.log(pred) * paddle.pow(1 - pred, 2) * pos_inds + neg_loss = paddle.log(1 - pred) * paddle.pow(pred, 2) * neg_weights * \ + neg_inds + + num_pos = pos_inds.astype('float32').sum() + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if num_pos == 0: + loss = loss - neg_loss + else: + loss = loss - (pos_loss + neg_loss) / num_pos + + if weight is not None: + if weight.shape != loss.shape: + if weight.shape[0] == loss.shape[0]: + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.reshape((-1, 1)) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.reshape((loss.shape[0], -1)) + assert weight.ndim == loss.ndim + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = 1e-10 + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + + return loss + +@register +@serializable +class CenterFocalLoss(nn.Layer): + """CenterFocalLoss is a variant of focal loss. 
+ + More details can be found in the `paper + `_ + + Args: + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Loss weight of current loss. + """ + + def __init__(self, + reduction='none', + loss_weight=1.0): + super(CenterFocalLoss, self).__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + mask=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction in gaussian + distribution. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + mask (Tensor): The valid mask. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss_reg = self.loss_weight * center_focal_loss( + pred, + target, + weight, + mask=mask, + reduction=reduction, + avg_factor=avg_factor) + return loss_reg + +def l1_loss(pred, target, weight=None, reduction='mean', avg_factor=None): + """L1 loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + + Returns: + Tensor: Calculated loss + """ + if not target.astype('bool').any(): + return pred.sum() * 0 + + assert pred.shape == target.shape + loss = paddle.abs(pred - target) + + if weight is not None: + if weight.shape != loss.shape: + if weight.shape[0] == loss.shape[0]: + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.reshape((-1, 1)) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.reshape((loss.shape[0], -1)) + assert weight.ndim == loss.ndim + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = 1e-10 + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + + + return loss + +@register +@serializable +class L1Loss(nn.Layer): + """L1 loss. + + Args: + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of loss. + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super(L1Loss, self).__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. 
+ weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss_bbox = self.loss_weight * l1_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss_bbox + diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index 51d367b27d4feb49ddfea03133aa0042ef6d4b7b..478efec98e324b213ad3f822b551f92265d91e25 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -36,3 +36,4 @@ from .es_pan import * from .lc_pan import * from .custom_pan import * from .dilated_encoder import * +from .channel_mapper import * diff --git a/ppdet/modeling/necks/channel_mapper.py b/ppdet/modeling/necks/channel_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6eff3f85476815351e2ec25750949c4cba74da84 --- /dev/null +++ b/ppdet/modeling/necks/channel_mapper.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +this code is base on mmdet: git@github.com:open-mmlab/mmdetection.git +""" +import paddle.nn as nn + +from ppdet.core.workspace import register, serializable +from ..backbones.hrnet import ConvNormLayer +from ..shape_spec import ShapeSpec +from ..initializer import xavier_uniform_, constant_ + +__all__ = ['ChannelMapper'] + + +@register +@serializable +class ChannelMapper(nn.Layer): + """Channel Mapper to reduce/increase channels of backbone features. + + This is used to reduce/increase channels of backbone features. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + kernel_size (int, optional): kernel_size for reducing channels (used + at each scale). Default: 3. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU'). + num_outs (int, optional): Number of output feature maps. There + would be extra_convs when num_outs larger than the length + of in_channels. + init_cfg (dict or list[dict], optional): Initialization config dict. 
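+    Example:
+        A minimal usage sketch (feature shapes are illustrative and the
+        arguments mirror the PETR config; ConvNormLayer behaviour is assumed
+        from the constructor below):
+
+        >>> import paddle
+        >>> feats = [paddle.rand([1, 512, 64, 64]),
+        ...          paddle.rand([1, 1024, 32, 32]),
+        ...          paddle.rand([1, 2048, 16, 16])]
+        >>> neck = ChannelMapper(in_channels=[512, 1024, 2048],
+        ...                      out_channels=256, kernel_size=1, num_outs=4)
+        >>> outs = neck(feats)   # 3 mapped levels + 1 extra stride-2 level
+        >>> len(outs)
+        4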
+ + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + norm_type="gn", + norm_groups=32, + act='relu', + num_outs=None, + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(ChannelMapper, self).__init__() + assert isinstance(in_channels, list) + self.extra_convs = None + if num_outs is None: + num_outs = len(in_channels) + self.convs = nn.LayerList() + for in_channel in in_channels: + self.convs.append( + ConvNormLayer( + ch_in=in_channel, + ch_out=out_channels, + filter_size=kernel_size, + norm_type='gn', + norm_groups=32, + act=act)) + + if num_outs > len(in_channels): + self.extra_convs = nn.LayerList() + for i in range(len(in_channels), num_outs): + if i == len(in_channels): + in_channel = in_channels[-1] + else: + in_channel = out_channels + self.extra_convs.append( + ConvNormLayer( + ch_in=in_channel, + ch_out=out_channels, + filter_size=3, + stride=2, + norm_type='gn', + norm_groups=32, + act=act)) + self.init_weights() + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.convs) + outs = [self.convs[i](inputs[i]) for i in range(len(inputs))] + if self.extra_convs: + for i in range(len(self.extra_convs)): + if i == 0: + outs.append(self.extra_convs[0](inputs[-1])) + else: + outs.append(self.extra_convs[i](outs[-1])) + return tuple(outs) + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.out_channel, stride=1. / s) + for s in self.spatial_scales + ] + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.rank() > 1: + xavier_uniform_(p) + if hasattr(p, 'bias') and p.bias is not None: + constant_(p.bais) diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index 9be26fc3463c2a4aacb71d7791573f4e4e970124..e55cb0c1de9d62154a93cd8d6a101ef8fe51d356 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -25,3 +25,4 @@ from .matchers import * from .position_encoding import * from .deformable_transformer import * from .dino_transformer import * +from .petr_transformer import * diff --git a/ppdet/modeling/transformers/petr_transformer.py b/ppdet/modeling/transformers/petr_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7859b0df028bf5da7a615c427de4fb0850bfca2e --- /dev/null +++ b/ppdet/modeling/transformers/petr_transformer.py @@ -0,0 +1,1198 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
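+# TODO: MultiScaleDeformablePoseAttention.__init__ below calls warnings.warn(),
+# but this file does not appear to import `warnings`; add `import warnings`
+# alongside the other imports.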
+""" +this code is base on https://github.com/hikvision-research/opera/blob/main/opera/models/utils/transformer.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr + +from ppdet.core.workspace import register +from ..layers import MultiHeadAttention, _convert_attention_mask +from .utils import _get_clones +from ..initializer import linear_init_, normal_, constant_, xavier_uniform_ + +__all__ = [ + 'PETRTransformer', 'MultiScaleDeformablePoseAttention', + 'PETR_TransformerDecoderLayer', 'PETR_TransformerDecoder', + 'PETR_DeformableDetrTransformerDecoder', + 'PETR_DeformableTransformerDecoder', 'TransformerEncoderLayer', + 'TransformerEncoder', 'MSDeformableAttention' +] + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clip(min=0, max=1) + x1 = x.clip(min=eps) + x2 = (1 - x).clip(min=eps) + return paddle.log(x1 / x2) + + +@register +class TransformerEncoderLayer(nn.Layer): + __inject__ = ['attn'] + + def __init__(self, + d_model, + attn=None, + nhead=8, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False): + super(TransformerEncoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + self.embed_dims = d_model + + if attn is None: + self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + else: + self.self_attn = attn + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train") + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward(self, src, src_mask=None, pos_embed=None, **kwargs): + residual = src + if self.normalize_before: + src = self.norm1(src) + q = k = self.with_pos_embed(src, pos_embed) + src = self.self_attn(q, k, value=src, attn_mask=src_mask, **kwargs) + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src + + +@register +class TransformerEncoder(nn.Layer): + __inject__ = ['encoder_layer'] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = 
_get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.embed_dims = encoder_layer.embed_dims + + def forward(self, src, src_mask=None, pos_embed=None, **kwargs): + output = src + for layer in self.layers: + output = layer( + output, src_mask=src_mask, pos_embed=pos_embed, **kwargs) + + if self.norm is not None: + output = self.norm(output) + + return output + + +@register +class MSDeformableAttention(nn.Layer): + def __init__(self, + embed_dim=256, + num_heads=8, + num_levels=4, + num_points=4, + lr_mult=0.1): + """ + Multi-Scale Deformable Attention Module + """ + super(MSDeformableAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.num_points = num_points + self.total_points = num_heads * num_levels * num_points + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear( + embed_dim, + self.total_points * 2, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult)) + + self.attention_weights = nn.Linear(embed_dim, self.total_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + try: + # use cuda op + print("use deformable_detr_ops in ms_deformable_attn") + from deformable_detr_ops import ms_deformable_attn + except: + # use paddle func + from .utils import deformable_attention_core_func as ms_deformable_attn + self.ms_deformable_attn_core = ms_deformable_attn + + self._reset_parameters() + + def _reset_parameters(self): + # sampling_offsets + constant_(self.sampling_offsets.weight) + thetas = paddle.arange( + self.num_heads, + dtype=paddle.float32) * (2.0 * math.pi / self.num_heads) + grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True) + grid_init = grid_init.reshape([self.num_heads, 1, 1, 2]).tile( + [1, self.num_levels, self.num_points, 1]) + scaling = paddle.arange( + 1, self.num_points + 1, + dtype=paddle.float32).reshape([1, 1, -1, 1]) + grid_init *= scaling + self.sampling_offsets.bias.set_value(grid_init.flatten()) + # attention_weights + constant_(self.attention_weights.weight) + constant_(self.attention_weights.bias) + # proj + xavier_uniform_(self.value_proj.weight) + constant_(self.value_proj.bias) + xavier_uniform_(self.output_proj.weight) + constant_(self.output_proj.bias) + + def forward(self, + query, + key, + value, + reference_points, + value_spatial_shapes, + value_level_start_index, + attn_mask=None, + **kwargs): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...] 
+ attn_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + assert int(value_spatial_shapes.prod(1).sum()) == Len_v + + value = self.value_proj(value) + if attn_mask is not None: + attn_mask = attn_mask.astype(value.dtype).unsqueeze(-1) + value *= attn_mask + value = value.reshape([bs, Len_v, self.num_heads, self.head_dim]) + + sampling_offsets = self.sampling_offsets(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2]) + attention_weights = self.attention_weights(query).reshape( + [bs, Len_q, self.num_heads, self.num_levels * self.num_points]) + attention_weights = F.softmax(attention_weights).reshape( + [bs, Len_q, self.num_heads, self.num_levels, self.num_points]) + + if reference_points.shape[-1] == 2: + offset_normalizer = value_spatial_shapes.flip([1]).reshape( + [1, 1, 1, self.num_levels, 1, 2]) + sampling_locations = reference_points.reshape([ + bs, Len_q, 1, self.num_levels, 1, 2 + ]) + sampling_offsets / offset_normalizer + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + sampling_offsets / + self.num_points * reference_points[:, :, None, :, None, 2:] * + 0.5) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.". + format(reference_points.shape[-1])) + + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) + output = self.output_proj(output) + + return output + + +@register +class MultiScaleDeformablePoseAttention(nn.Layer): + """An attention module used in PETR. `End-to-End Multi-Person + Pose Estimation with Transformers`. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 8. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for + each query in each head. Default: 17. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_residual`. + Default: 0.1. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. 
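+
+    Note:
+        Unlike standard deformable attention, each query's reference here is a
+        whole pose, so `num_points` equals the number of keypoints (17 for
+        COCO). In `forward`, the sampling offsets are scaled by half the
+        width/height of the box circumscribing the reference keypoints before
+        being added to the per-keypoint reference locations.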
+ """ + + def __init__(self, + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=17, + im2col_step=64, + dropout=0.1, + norm_cfg=None, + init_cfg=None, + batch_first=False, + lr_mult=0.1): + super().__init__() + if embed_dims % num_heads != 0: + raise ValueError(f'embed_dims must be divisible by num_heads, ' + f'but got {embed_dims} and {num_heads}') + dim_per_head = embed_dims // num_heads + self.norm_cfg = norm_cfg + self.init_cfg = init_cfg + self.dropout = nn.Dropout(dropout) + self.batch_first = batch_first + + # you'd better set dim_per_head to a power of 2 + # which is more efficient in the CUDA implementation + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + if not _is_power_of_2(dim_per_head): + warnings.warn("You'd better set embed_dims in " + 'MultiScaleDeformAttention to make ' + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, + num_heads * num_levels * num_points * 2, + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult)) + self.attention_weights = nn.Linear(embed_dims, + num_heads * num_levels * num_points) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + + try: + # use cuda op + from deformable_detr_ops import ms_deformable_attn + except: + # use paddle func + from .utils import deformable_attention_core_func as ms_deformable_attn + self.ms_deformable_attn_core = ms_deformable_attn + + self.init_weights() + + def init_weights(self): + """Default initialization for Parameters of Module.""" + constant_(self.sampling_offsets.weight) + constant_(self.sampling_offsets.bias) + constant_(self.attention_weights.weight) + constant_(self.attention_weights.bias) + xavier_uniform_(self.value_proj.weight) + constant_(self.value_proj.bias) + xavier_uniform_(self.output_proj.weight) + constant_(self.output_proj.bias) + + def forward(self, + query, + key, + value, + residual=None, + attn_mask=None, + reference_points=None, + value_spatial_shapes=None, + value_level_start_index=None, + **kwargs): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + key (Tensor): The key tensor with shape (num_key, bs, embed_dims). + value (Tensor): The value tensor with shape + (num_key, bs, embed_dims). + residual (Tensor): The tensor used for addition, with the + same shape as `x`. Default None. If None, `x` will be used. + reference_points (Tensor): The normalized reference points with + shape (bs, num_query, num_levels, K*2), all elements is range + in [0, 1], top-left (0,0), bottom-right (1, 1), including + padding area. + attn_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + value_spatial_shapes (Tensor): Spatial shape of features in + different level. With shape (num_levels, 2), + last dimension represent (h, w). + value_level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. 
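+
+        Note:
+            The shape comments above follow the reference mmdet/opera
+            implementation, i.e. (num_query, bs, embed_dims); in this Paddle
+            port the tensors are batch-first, so query/key/value and the
+            returned tensor have shape (bs, num_query, embed_dims).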
+ """ + + if key is None: + key = query + if value is None: + value = key + + bs, num_query, _ = query.shape + bs, num_key, _ = value.shape + assert (value_spatial_shapes[:, 0].numpy() * + value_spatial_shapes[:, 1].numpy()).sum() == num_key + + value = self.value_proj(value) + if attn_mask is not None: + # value = value.masked_fill(attn_mask[..., None], 0.0) + value *= attn_mask.unsqueeze(-1) + value = value.reshape([bs, num_key, self.num_heads, -1]) + sampling_offsets = self.sampling_offsets(query).reshape([ + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2 + ]) + attention_weights = self.attention_weights(query).reshape( + [bs, num_query, self.num_heads, self.num_levels * self.num_points]) + attention_weights = F.softmax(attention_weights, axis=-1) + + attention_weights = attention_weights.reshape( + [bs, num_query, self.num_heads, self.num_levels, self.num_points]) + if reference_points.shape[-1] == self.num_points * 2: + reference_points_reshape = reference_points.reshape( + (bs, num_query, self.num_levels, -1, 2)).unsqueeze(2) + x1 = reference_points[:, :, :, 0::2].min(axis=-1, keepdim=True) + y1 = reference_points[:, :, :, 1::2].min(axis=-1, keepdim=True) + x2 = reference_points[:, :, :, 0::2].max(axis=-1, keepdim=True) + y2 = reference_points[:, :, :, 1::2].max(axis=-1, keepdim=True) + w = paddle.clip(x2 - x1, min=1e-4) + h = paddle.clip(y2 - y1, min=1e-4) + wh = paddle.concat([w, h], axis=-1)[:, :, None, :, None, :] + + sampling_locations = reference_points_reshape \ + + sampling_offsets * wh * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2K, but get {reference_points.shape[-1]} instead.') + + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights) + + output = self.output_proj(output) + return output + + +@register +class PETR_TransformerDecoderLayer(nn.Layer): + __inject__ = ['self_attn', 'cross_attn'] + + def __init__(self, + d_model, + nhead=8, + self_attn=None, + cross_attn=None, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False): + super(PETR_TransformerDecoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + if self_attn is None: + self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + else: + self.self_attn = self_attn + if cross_attn is None: + self.cross_attn = MultiHeadAttention(d_model, nhead, attn_dropout) + else: + self.cross_attn = cross_attn + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train") + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout3 = nn.Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + self._reset_parameters() + + def _reset_parameters(self): + linear_init_(self.linear1) + linear_init_(self.linear2) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + 
pos_embed=None, + query_pos_embed=None, + **kwargs): + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + q = k = self.with_pos_embed(tgt, query_pos_embed) + tgt = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask) + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + q = self.with_pos_embed(tgt, query_pos_embed) + key_tmp = tgt + # k = self.with_pos_embed(memory, pos_embed) + tgt = self.cross_attn( + q, key=key_tmp, value=memory, attn_mask=memory_mask, **kwargs) + tgt = residual + self.dropout2(tgt) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +@register +class PETR_TransformerDecoder(nn.Layer): + """Implements the decoder in PETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + __inject__ = ['decoder_layer'] + + def __init__(self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False, + num_keypoints=17, + **kwargs): + super(PETR_TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + self.num_keypoints = num_keypoints + + def forward(self, + query, + *args, + reference_points=None, + valid_ratios=None, + kpt_branches=None, + **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape (num_query, bs, embed_dims). + reference_points (Tensor): The reference points of offset, + has shape (bs, num_query, K*2). + valid_ratios (Tensor): The radios of valid points on the feature + map, has shape (bs, num_levels, 2). + kpt_branches: (obj:`nn.LayerList`): Used for refining the + regression results. Only would be passed when `with_box_refine` + is True, otherwise would be passed a `None`. + + Returns: + tuple (Tensor): Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims] and + [num_layers, bs, num_query, K*2]. 
+ """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == self.num_keypoints * 2: + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios.tile((1, 1, self.num_keypoints))[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + *args, + reference_points=reference_points_input, + **kwargs) + + if kpt_branches is not None: + tmp = kpt_branches[lid](output) + if reference_points.shape[-1] == self.num_keypoints * 2: + new_reference_points = tmp + inverse_sigmoid( + reference_points) + new_reference_points = F.sigmoid(new_reference_points) + else: + raise NotImplementedError + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return paddle.stack(intermediate), paddle.stack( + intermediate_reference_points) + + return output, reference_points + + +@register +class PETR_DeformableTransformerDecoder(nn.Layer): + __inject__ = ['decoder_layer'] + + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super(PETR_DeformableTransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_mask=None, + query_pos_embed=None): + output = tgt + intermediate = [] + for lid, layer in enumerate(self.layers): + output = layer(output, reference_points, memory, + memory_spatial_shapes, memory_mask, query_pos_embed) + + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + return paddle.stack(intermediate) + + return output.unsqueeze(0) + + +@register +class PETR_DeformableDetrTransformerDecoder(PETR_DeformableTransformerDecoder): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(PETR_DeformableDetrTransformerDecoder, self).__init__(*args, + **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, + query, + *args, + reference_points=None, + valid_ratios=None, + reg_branches=None, + **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.LayerList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
+ """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = reference_points[:, :, None] * \ + paddle.concat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + *args, + reference_points=reference_points_input, + **kwargs) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid( + reference_points) + new_reference_points = F.sigmoid(new_reference_points) + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = F.sigmoid(new_reference_points) + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return paddle.stack(intermediate), paddle.stack( + intermediate_reference_points) + + return output, reference_points + + +@register +class PETRTransformer(nn.Layer): + """Implements the PETR transformer. + + Args: + as_two_stage (bool): Generate query from encoder features. + Default: False. + num_feature_levels (int): Number of feature maps from FPN: + Default: 4. + two_stage_num_proposals (int): Number of proposals when set + `as_two_stage` as True. Default: 300. + """ + __inject__ = ["encoder", "decoder", "hm_encoder", "refine_decoder"] + + def __init__(self, + encoder="", + decoder="", + hm_encoder="", + refine_decoder="", + as_two_stage=True, + num_feature_levels=4, + two_stage_num_proposals=300, + num_keypoints=17, + **kwargs): + super(PETRTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.num_keypoints = num_keypoints + self.encoder = encoder + self.decoder = decoder + self.embed_dims = self.encoder.embed_dims + self.hm_encoder = hm_encoder + self.refine_decoder = refine_decoder + self.init_layers() + self.init_weights() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + #paddle.create_parameter + self.level_embeds = paddle.create_parameter( + (self.num_feature_levels, self.embed_dims), dtype="float32") + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.refine_query_embedding = nn.Embedding(self.num_keypoints, + self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, + 2 * self.num_keypoints) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.rank() > 1: + xavier_uniform_(p) + if hasattr(p, 'bias') and p.bias is not None: + constant_(p.bais) + for m in self.sublayers(): + if isinstance(m, MSDeformableAttention): + m._reset_parameters() + for m in self.sublayers(): + if isinstance(m, MultiScaleDeformablePoseAttention): + m.init_weights() + if not self.as_two_stage: + xavier_uniform_(self.reference_points.weight) + constant_(self.reference_points.bias) + normal_(self.level_embeds) + normal_(self.refine_query_embedding.weight) + + def gen_encoder_output_proposals(self, memory, 
memory_padding_mask, + spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor): The output of encoder, has shape + (bs, num_key, embed_dim). num_key is equal the number of points + on feature map from all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. + + - output_memory (Tensor): The input of decoder, has shape + (bs, num_key, embed_dim). num_key is equal the number of + points on feature map from all levels. + - output_proposals (Tensor): The normalized proposal + after a inverse sigmoid, has shape (bs, num_keys, 4). + """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].reshape( + [N, H, W, 1]) + valid_H = paddle.sum(mask_flatten_[:, :, 0, 0], 1) + valid_W = paddle.sum(mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = paddle.meshgrid( + paddle.linspace( + 0, H - 1, H, dtype="float32"), + paddle.linspace( + 0, W - 1, W, dtype="float32")) + grid = paddle.concat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], + -1) + + scale = paddle.concat( + [valid_W.unsqueeze(-1), + valid_H.unsqueeze(-1)], 1).reshape([N, 1, 1, 2]) + grid = (grid.unsqueeze(0).expand((N, -1, -1, -1)) + 0.5) / scale + proposal = grid.reshape([N, -1, 2]) + proposals.append(proposal) + _cur += (H * W) + output_proposals = paddle.concat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & + (output_proposals < 0.99)).all( + -1, keepdim=True).astype("bool") + output_proposals = paddle.log(output_proposals / (1 - output_proposals)) + output_proposals = masked_fill( + output_proposals, ~memory_padding_mask.astype("bool").unsqueeze(-1), + float('inf')) + output_proposals = masked_fill(output_proposals, + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = masked_fill( + output_memory, ~memory_padding_mask.astype("bool").unsqueeze(-1), + float(0)) + output_memory = masked_fill(output_memory, ~output_proposals_valid, + float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios): + """Get the reference points used in decoder. + + Args: + spatial_shapes (Tensor): The shape of all feature maps, + has shape (num_level, 2). + valid_ratios (Tensor): The radios of valid points on the + feature map, has shape (bs, num_levels, 2). + + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). 
+ """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = paddle.meshgrid( + paddle.linspace( + 0.5, H - 0.5, H, dtype="float32"), + paddle.linspace( + 0.5, W - 0.5, W, dtype="float32")) + ref_y = ref_y.reshape( + (-1, ))[None] / (valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape( + (-1, ))[None] / (valid_ratios[:, None, lvl, 0] * W) + ref = paddle.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = paddle.concat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = paddle.sum(mask[:, :, 0].astype('float'), 1) + valid_W = paddle.sum(mask[:, 0, :].astype('float'), 1) + valid_ratio_h = valid_H.astype('float') / H + valid_ratio_w = valid_W.astype('float') / W + valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, + proposals, + num_pos_feats=128, + temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = paddle.arange(num_pos_feats, dtype="float32") + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = F.sigmoid(proposals) * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = paddle.stack( + (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + axis=4).flatten(2) + return pos + + def forward(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + kpt_branches=None, + cls_branches=None): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from different level. + Each element has shape [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from different + level used for encoder and decoder, each element has shape + [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + kpt_branches (obj:`nn.LayerList`): Keypoint Regression heads for + feature maps from each decoder layer. Only would be passed when + `with_box_refine` is Ture. Default to None. + cls_branches (obj:`nn.LayerList`): Classification heads for + feature maps from each decoder layer. Only would be passed when + `as_two_stage` is Ture. Default to None. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + `return_intermediate_dec` is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of proposals \ + generated from encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_kpt_unact: The regression results generated from \ + encoder's feature maps., has shape (batch, h*w, K*2). + Only would be returned when `as_two_stage` is True, \ + otherwise None. 
+ """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed + ) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose((0, 2, 1)) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose((0, 2, 1)) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].reshape( + [1, 1, -1]) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = paddle.concat(feat_flatten, 1) + mask_flatten = paddle.concat(mask_flatten, 1) + lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1) + spatial_shapes_cumsum = paddle.to_tensor( + np.array(spatial_shapes).prod(1).cumsum(0)) + spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64") + level_start_index = paddle.concat((paddle.zeros( + (1, ), dtype=spatial_shapes.dtype), spatial_shapes_cumsum[:-1])) + valid_ratios = paddle.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = \ + self.get_reference_points(spatial_shapes, + valid_ratios) + + memory = self.encoder( + src=feat_flatten, + pos_embed=lvl_pos_embed_flatten, + src_mask=mask_flatten, + value_spatial_shapes=spatial_shapes, + reference_points=reference_points, + value_level_start_index=level_start_index, + valid_ratios=valid_ratios) + + bs, _, c = memory.shape + + hm_proto = None + if self.training: + hm_memory = paddle.slice( + memory, + starts=level_start_index[0], + ends=level_start_index[1], + axes=[1]) + hm_pos_embed = paddle.slice( + lvl_pos_embed_flatten, + starts=level_start_index[0], + ends=level_start_index[1], + axes=[1]) + hm_mask = paddle.slice( + mask_flatten, + starts=level_start_index[0], + ends=level_start_index[1], + axes=[1]) + hm_reference_points = paddle.slice( + reference_points, + starts=level_start_index[0], + ends=level_start_index[1], + axes=[1])[:, :, :1, :] + + # official code make a mistake of pos_embed to pose_embed, which disable pos_embed + hm_memory = self.hm_encoder( + src=hm_memory, + pose_embed=hm_pos_embed, + src_mask=hm_mask, + value_spatial_shapes=spatial_shapes[[0]], + reference_points=hm_reference_points, + value_level_start_index=level_start_index[0], + valid_ratios=valid_ratios[:, :1, :]) + hm_memory = hm_memory.reshape((bs, spatial_shapes[0, 0], + spatial_shapes[0, 1], -1)) + hm_proto = (hm_memory, mlvl_masks[0]) + + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers]( + output_memory) + enc_outputs_kpt_unact = \ + kpt_branches[self.decoder.num_layers](output_memory) + enc_outputs_kpt_unact[..., 0::2] += output_proposals[..., 0:1] + enc_outputs_kpt_unact[..., 1::2] += output_proposals[..., 1:2] + + topk = self.two_stage_num_proposals + topk_proposals = paddle.topk( + enc_outputs_class[..., 0], topk, axis=1)[1].unsqueeze(-1) + + #paddle.take_along_axis 对应torch.gather + topk_kpts_unact = paddle.take_along_axis(enc_outputs_kpt_unact, + topk_proposals, 1) + topk_kpts_unact = topk_kpts_unact.detach() + + reference_points = F.sigmoid(topk_kpts_unact) + init_reference_out = reference_points + # learnable query and query_pos + query_pos, query = paddle.split( + query_embed, query_embed.shape[1] // c, axis=1) + query_pos = query_pos.unsqueeze(0).expand((bs, -1, 
-1)) + query = query.unsqueeze(0).expand((bs, -1, -1)) + else: + query_pos, query = paddle.split( + query_embed, query_embed.shape[1] // c, axis=1) + query_pos = query_pos.unsqueeze(0).expand((bs, -1, -1)) + query = query.unsqueeze(0).expand((bs, -1, -1)) + reference_points = F.sigmoid(self.reference_points(query_pos)) + init_reference_out = reference_points + + # decoder + inter_states, inter_references = self.decoder( + query=query, + memory=memory, + query_pos_embed=query_pos, + memory_mask=mask_flatten, + reference_points=reference_points, + value_spatial_shapes=spatial_shapes, + value_level_start_index=level_start_index, + valid_ratios=valid_ratios, + kpt_branches=kpt_branches) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out, \ + inter_references_out, enc_outputs_class, \ + enc_outputs_kpt_unact, hm_proto, memory + return inter_states, init_reference_out, \ + inter_references_out, None, None, None, None, None, hm_proto + + def forward_refine(self, + mlvl_masks, + memory, + reference_points_pose, + img_inds, + kpt_branches=None, + **kwargs): + mask_flatten = [] + spatial_shapes = [] + for lvl, mask in enumerate(mlvl_masks): + bs, h, w = mask.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + mask = mask.flatten(1) + mask_flatten.append(mask) + mask_flatten = paddle.concat(mask_flatten, 1) + spatial_shapes_cumsum = paddle.to_tensor( + np.array( + spatial_shapes, dtype='int64').prod(1).cumsum(0)) + spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64") + level_start_index = paddle.concat((paddle.zeros( + (1, ), dtype=spatial_shapes.dtype), spatial_shapes_cumsum[:-1])) + valid_ratios = paddle.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + # pose refinement (17 queries corresponding to 17 keypoints) + # learnable query and query_pos + refine_query_embedding = self.refine_query_embedding.weight + query_pos, query = paddle.split(refine_query_embedding, 2, axis=1) + pos_num = reference_points_pose.shape[0] + query_pos = query_pos.unsqueeze(0).expand((pos_num, -1, -1)) + query = query.unsqueeze(0).expand((pos_num, -1, -1)) + reference_points = reference_points_pose.reshape( + (pos_num, reference_points_pose.shape[1] // 2, 2)) + pos_memory = memory[img_inds] + mask_flatten = mask_flatten[img_inds] + valid_ratios = valid_ratios[img_inds] + if img_inds.size == 1: + pos_memory = pos_memory.unsqueeze(0) + mask_flatten = mask_flatten.unsqueeze(0) + valid_ratios = valid_ratios.unsqueeze(0) + inter_states, inter_references = self.refine_decoder( + query=query, + memory=pos_memory, + query_pos_embed=query_pos, + memory_mask=mask_flatten, + reference_points=reference_points, + value_spatial_shapes=spatial_shapes, + value_level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=kpt_branches, + **kwargs) + # [num_decoder, num_query, bs, embed_dim] + + init_reference_out = reference_points + return inter_states, init_reference_out, inter_references diff --git a/ppdet/utils/visualizer.py b/ppdet/utils/visualizer.py index f7193306c93e0917ee400df3f76f28a3f436df08..1c8560a7453d1abce9815f874d9ca357fc516072 100644 --- a/ppdet/utils/visualizer.py +++ b/ppdet/utils/visualizer.py @@ -238,7 +238,7 @@ def draw_pose(image, 'for example: `pip install matplotlib`.') raise e - skeletons = np.array([item['keypoints'] for item in results]) + skeletons = np.array([item['keypoints'] for item in results]).reshape((-1, 51)) kpt_nums = 17 if len(skeletons) > 0: kpt_nums = int(skeletons.shape[1] / 
3)
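Context for the draw_pose change: each person's `keypoints` carries the 17 COCO joints as (x, y, score), i.e. 51 values. When results store those values as a nested 17x3 list per person, `np.array(...)` yields shape (num_person, 17, 3) and `skeletons.shape[1] / 3` no longer gives the keypoint count, so the reshape to (-1, 51) restores the flat layout draw_pose expects. A minimal sketch of that shape handling (illustrative only; the nested layout and dummy values are assumptions, not PaddleDetection API):

    import numpy as np

    # Hypothetical results list: two people, keypoints nested as 17 x (x, y, score).
    results = [{'keypoints': np.zeros((17, 3)).tolist()} for _ in range(2)]

    skeletons = np.array([item['keypoints'] for item in results])  # shape (2, 17, 3)
    skeletons = skeletons.reshape((-1, 51))                        # shape (2, 51)
    kpt_nums = int(skeletons.shape[1] / 3)                         # 17, as draw_pose expects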