From 64e5cd96d756ad7ace2dac571c4cb6cc65daa33d Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 23 Apr 2020 12:22:58 +0800 Subject: [PATCH] Add cornernet (#438) * code for cornernet * add doc * refine custom op * refine code from review comments * refine doc * update doc * update code * add cornernet_squeeze_mixup_cosine config * update code * update code --- configs/anchor_free/README.md | 65 +++ configs/anchor_free/cornernet_squeeze.yml | 147 ++++++ .../cornernet_squeeze_dcn_r50_vd_fpn.yml | 165 ++++++ ...et_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml | 169 ++++++ .../cornernet_squeeze_r50_vd_fpn.yml | 157 ++++++ docs/MODEL_ZOO.md | 4 + docs/MODEL_ZOO_cn.md | 4 + docs/featured_model/ANCHOR_FREE_DETECTION.md | 1 + ppdet/data/transform/op_helper.py | 49 ++ ppdet/data/transform/operators.py | 309 ++++++++++- ppdet/ext_op/README.md | 63 +++ ppdet/ext_op/__init__.py | 18 + ppdet/ext_op/cornerpool_lib.py | 189 +++++++ ppdet/ext_op/src/bottom_pool_op.cc | 101 ++++ ppdet/ext_op/src/bottom_pool_op.cu | 104 ++++ ppdet/ext_op/src/left_pool_op.cc | 101 ++++ ppdet/ext_op/src/left_pool_op.cu | 106 ++++ ppdet/ext_op/src/make.sh | 23 + ppdet/ext_op/src/right_pool_op.cc | 101 ++++ ppdet/ext_op/src/right_pool_op.cu | 105 ++++ ppdet/ext_op/src/top_pool_op.cc | 102 ++++ ppdet/ext_op/src/top_pool_op.cu | 104 ++++ ppdet/ext_op/src/util.cu.h | 223 ++++++++ ppdet/ext_op/test/test_corner_pool.py | 120 +++++ ppdet/modeling/anchor_heads/__init__.py | 2 + ppdet/modeling/anchor_heads/corner_head.py | 486 ++++++++++++++++++ ppdet/modeling/architectures/__init__.py | 2 + .../architectures/cascade_mask_rcnn.py | 2 +- .../architectures/cornernet_squeeze.py | 138 +++++ ppdet/modeling/architectures/mask_rcnn.py | 2 +- ppdet/modeling/backbones/__init__.py | 2 + ppdet/modeling/backbones/hourglass.py | 275 ++++++++++ ppdet/modeling/ops.py | 2 +- ppdet/optimizer.py | 10 +- ppdet/utils/coco_eval.py | 3 +- ppdet/utils/eval_utils.py | 9 +- ppdet/utils/post_process.py | 102 +++- 37 files changed, 3527 insertions(+), 38 deletions(-) create mode 100644 configs/anchor_free/README.md create mode 100644 configs/anchor_free/cornernet_squeeze.yml create mode 100644 configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml create mode 100644 configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml create mode 100644 configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml create mode 100644 docs/featured_model/ANCHOR_FREE_DETECTION.md create mode 100644 ppdet/ext_op/README.md create mode 100644 ppdet/ext_op/__init__.py create mode 100644 ppdet/ext_op/cornerpool_lib.py create mode 100644 ppdet/ext_op/src/bottom_pool_op.cc create mode 100644 ppdet/ext_op/src/bottom_pool_op.cu create mode 100644 ppdet/ext_op/src/left_pool_op.cc create mode 100644 ppdet/ext_op/src/left_pool_op.cu create mode 100755 ppdet/ext_op/src/make.sh create mode 100644 ppdet/ext_op/src/right_pool_op.cc create mode 100644 ppdet/ext_op/src/right_pool_op.cu create mode 100755 ppdet/ext_op/src/top_pool_op.cc create mode 100755 ppdet/ext_op/src/top_pool_op.cu create mode 100644 ppdet/ext_op/src/util.cu.h create mode 100755 ppdet/ext_op/test/test_corner_pool.py create mode 100644 ppdet/modeling/anchor_heads/corner_head.py create mode 100644 ppdet/modeling/architectures/cornernet_squeeze.py create mode 100644 ppdet/modeling/backbones/hourglass.py diff --git a/configs/anchor_free/README.md b/configs/anchor_free/README.md new file mode 100644 index 000000000..dcc6de223 --- /dev/null +++ b/configs/anchor_free/README.md @@ -0,0 +1,65 @@ +# Anchor Free系列模型 + +## 内容 +- 
[简介](#简介) +- [模型库与基线](#模型库与基线) +- [算法细节](#算法细节) +- [如何贡献代码](#如何贡献代码) + +## 简介 +目前主流的检测算法大体分为两类: single-stage和two-stage,其中single-stage的经典算法包括SSD, YOLO等,two-stage方法有RCNN系列模型,两大类算法在[PaddleDetection Model Zoo](../MODEL_ZOO.md)中均有给出,它们的共同特点是先定义一系列密集的,大小不等的anchor区域,再基于这些先验区域进行分类和回归,这种方式极大的受限于anchor自身的设计。随着CornerNet的提出,涌现了多种anchor free方法,PaddleDetection也集成了一系列anchor free算法。 + +## 模型库与基线 +下表中展示了PaddleDetection当前支持的网络结构,具体细节请参考[算法细节](#算法细节)。 + +| | ResNet50 | ResNet50-vd | Hourglass104 | +|:------------------------:|:--------:|:--------------------------:|:------------------------:| +| [CornerNet-Squeeze](#CornerNet-Squeeze) | x | ✓ | ✓ | +| [FCOS](#FCOS) | ✓ | x | x | + + +### 模型库 + +#### COCO数据集上的mAP + +| 网络结构 | 骨干网络 | 图片个数/GPU | 预训练模型 | mAP | FPS | 模型下载 | +|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:| +| CornerNet-Squeeze | Hourglass104 | 14 | 无 | 34.5 | 35.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_hg104.tar) | +| CornerNet-Squeeze | ResNet50-vd | 14 | [faster\_rcnn\_r50\_vd\_fpn\_2x](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar) | 32.7 | 42.45 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_r50_vd_fpn.tar) | +| CornerNet-Squeeze-dcn | ResNet50-vd | 14 | [faster\_rcnn\_dcn\_r50\_vd\_fpn\_2x](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar) | 34.9 | 40.05 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_dcn_r50_vd_fpn.tar) | +| CornerNet-Squeeze-dcn-mixup-cosine* | ResNet50-vd | 14 | [faster\_rcnn\_dcn\_r50\_vd\_fpn\_2x](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar) | 38.2 | 40.05 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.pdparams) | +| FCOS | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 39.8 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_1x.pdparams) | +| FCOS+multiscale_train | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 42.0 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_multiscale_2x.pdparams) | + +**注意:** + +- 模型FPS在Tesla V100单卡环境中通过tools/eval.py进行测试 +- CornerNet-Squeeze中使用ResNet结构的骨干网络时,加入了FPN结构,骨干网络的输出feature map采用FPN中的P3层输出。 +- \*CornerNet-Squeeze-dcn-mixup-cosine是基于原版CornerNet-Squeeze优化效果最好的模型,在ResNet的骨干网络基础上增加mixup预处理和使用cosine_decay +- FCOS使用GIoU loss、用location分支预测centerness、左上右下角点偏移量归一化和ground truth中心匹配策略 + +## 算法细节 + +### CornerNet-Squeeze + +**简介:** [CornerNet-Squeeze](https://arxiv.org/abs/1904.08900) 在[Cornernet](https://arxiv.org/abs/1808.01244)基础上进行改进,预测目标框的左上角和右下角的位置,同时参考SqueezeNet和MobileNet的特点,优化了CornerNet骨干网络Hourglass-104,大幅提升了模型预测速度,相较于原版[YOLO-v3](https://arxiv.org/abs/1804.02767),在训练精度和推理速度上都具备一定优势。 + +**特点:** + +- 使用corner_pooling获取候选框左上角和右下角的位置 +- 替换Hourglass-104中的residual block为SqueezeNet中的fire-module +- 替换第二层3x3卷积为3x3深度可分离卷积 + + +### FCOS + +**简介:** [FCOS](https://arxiv.org/abs/1904.01355)是一种密集预测的anchor-free检测算法,使用RetinaNet的骨架,直接在feature map上回归目标物体的长宽,并预测物体的类别以及centerness(feature map上像素点离物体中心的偏移程度),centerness最终会作为权重来调整物体得分。 + +**特点:** + +- 利用FPN结构在不同层预测不同scale的物体框,避免了同一feature map像素点处有多个物体框重叠的情况 +- 通过center-ness单层分支预测当前点是否是目标中心,消除低质量误检 + + +## 如何贡献代码 +我们非常欢迎您可以为PaddleDetection中的Anchor Free检测模型提供代码,您可以提交PR供我们review;也十分感谢您的反馈,可以提交相应issue,我们会及时解答。 
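
A minimal NumPy sketch of the corner pooling described above. It assumes the same semantics as the `left_pool`/`top_pool` custom ops added under `ppdet/ext_op` in this patch (`left_pool` takes a running max scanning right-to-left, `top_pool` scanning bottom-to-top); per the CornerNet paper, summing the two gives the top-left corner pooling map. Function names and shapes here are illustrative only, not the PR's implementation.

```python
import numpy as np

def left_pool(x):
    # running max from the right edge towards the left, per row (NCHW input)
    out = x.copy()
    for j in range(x.shape[-1] - 2, -1, -1):
        out[..., j] = np.maximum(out[..., j], out[..., j + 1])
    return out

def top_pool(x):
    # running max from the bottom edge towards the top, per column
    out = x.copy()
    for i in range(x.shape[-2] - 2, -1, -1):
        out[..., i, :] = np.maximum(out[..., i, :], out[..., i + 1, :])
    return out

feat = np.random.rand(1, 64, 10, 10).astype('float32')   # NCHW feature map
tl_pool = left_pool(feat) + top_pool(feat)                # top-left corner pooling
```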
diff --git a/configs/anchor_free/cornernet_squeeze.yml b/configs/anchor_free/cornernet_squeeze.yml new file mode 100644 index 000000000..a3e5ac0c5 --- /dev/null +++ b/configs/anchor_free/cornernet_squeeze.yml @@ -0,0 +1,147 @@ +architecture: CornerNetSqueeze +use_gpu: true +max_iters: 500000 +log_smooth_window: 20 +log_iter: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: NULL +weights: output/cornernet_squeeze/model_final +num_classes: 80 +stack: 2 + +CornerNetSqueeze: + backbone: Hourglass + corner_head: CornerHead + +Hourglass: + dims: [256, 256, 384, 384, 512] + modules: [2, 2, 2, 2, 4] + +CornerHead: + train_batch_size: 14 + test_batch_size: 1 + ae_threshold: 0.5 + num_dets: 100 + top_k: 20 + +PostProcess: + use_soft_nms: true + detections_per_im: 100 + nms_thresh: 0.001 + sigma: 0.5 + +LearningRate: + base_lr: 0.00025 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 450000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: NULL + +TrainReader: + inputs_def: + image_shape: [3, 511, 511] + fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks'] + output_size: 64 + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: False + - !CornerCrop + input_size: 511 + - !Resize + target_dim: 511 + - !RandomFlipImage + prob: 0.5 + - !CornerRandColor + saturation: 0.4 + contrast: 0.4 + brightness: 0.4 + - !Lighting + eigval: [0.2141788, 0.01817699, 0.00341571] + eigvec: [[-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938]] + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: False + is_channel_first: False + - !Permute + to_bgr: False + - !CornerTarget + output_size: [64, 64] + num_classes: 80 + batch_size: 14 + shuffle: true + drop_last: true + worker_num: 2 + use_process: true + drop_empty: false + +EvalReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + batch_size: 1 + drop_empty: false + worker_num: 2 + use_process: true + +TestReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + batch_size: 1 diff --git a/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml new file mode 100644 index 000000000..fbc653068 --- /dev/null +++ 
b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml @@ -0,0 +1,165 @@ +architecture: CornerNetSqueeze +use_gpu: true +max_iters: 500000 +log_smooth_window: 20 +log_iter: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar +weights: output/cornernet_squeeze_dcn_r50_vd_fpn/model_final +num_classes: 80 +stack: 1 + +CornerNetSqueeze: + backbone: ResNet + fpn: FPN + corner_head: CornerHead + +ResNet: + norm_type: bn + depth: 50 + feature_maps: [3, 4, 5] + freeze_at: 2 + variant: d + dcn_v2_stages: [3, 4, 5] + +FPN: + min_level: 3 + max_level: 6 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125] + +CornerHead: + train_batch_size: 14 + test_batch_size: 1 + ae_threshold: 0.5 + num_dets: 100 + top_k: 20 + +PostProcess: + use_soft_nms: true + detections_per_im: 100 + nms_thresh: 0.001 + sigma: 0.5 + +LearningRate: + base_lr: 0.0005 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 400000 + - 450000 + - !LinearWarmup + start_factor: 0. + steps: 4000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +TrainReader: + inputs_def: + image_shape: [3, 511, 511] + fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks'] + output_size: 64 + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: False + - !CornerCrop + input_size: 511 + - !Resize + target_dim: 511 + - !RandomFlipImage + prob: 0.5 + - !CornerRandColor + saturation: 0.4 + contrast: 0.4 + brightness: 0.4 + - !Lighting + eigval: [0.2141788, 0.01817699, 0.00341571] + eigvec: [[-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938]] + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: False + is_channel_first: False + - !Permute + to_bgr: False + - !CornerTarget + output_size: [64, 64] + num_classes: 80 + batch_size: 14 + shuffle: true + drop_last: true + worker_num: 2 + use_process: true + drop_empty: false + +EvalReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + use_process: true + batch_size: 1 + drop_empty: false + worker_num: 2 + +TestReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + batch_size: 1 diff --git 
a/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml new file mode 100644 index 000000000..e5ae6cb0d --- /dev/null +++ b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml @@ -0,0 +1,169 @@ +architecture: CornerNetSqueeze +use_gpu: true +max_iters: 500000 +log_smooth_window: 20 +log_iter: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar +weights: output/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine/model_final +num_classes: 80 +stack: 1 + +CornerNetSqueeze: + backbone: ResNet + fpn: FPN + corner_head: CornerHead + +ResNet: + norm_type: bn + depth: 50 + feature_maps: [3, 4, 5] + freeze_at: 2 + variant: d + dcn_v2_stages: [3, 4, 5] + +FPN: + min_level: 3 + max_level: 6 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125] + +CornerHead: + train_batch_size: 14 + test_batch_size: 1 + ae_threshold: 0.5 + num_dets: 100 + top_k: 20 + +PostProcess: + use_soft_nms: true + detections_per_im: 100 + nms_thresh: 0.001 + sigma: 0.5 + +LearningRate: + base_lr: 0.005 + schedulers: + - !CosineDecay + max_iters: 500000 + - !LinearWarmup + start_factor: 0. + steps: 4000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +TrainReader: + inputs_def: + image_shape: [3, 511, 511] + fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks'] + output_size: 64 + max_tag_len: 256 + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: False + with_mixup: True + - !MixupImage + alpha: 1.5 + beta: 1.5 + - !CornerCrop + input_size: 511 + - !Resize + target_dim: 511 + - !RandomFlipImage + prob: 0.5 + - !CornerRandColor + saturation: 0.4 + contrast: 0.4 + brightness: 0.4 + - !Lighting + eigval: [0.2141788, 0.01817699, 0.00341571] + eigvec: [[-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938]] + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: False + is_channel_first: False + - !Permute + to_bgr: False + - !CornerTarget + output_size: [64, 64] + num_classes: 80 + max_tag_len: 256 + batch_size: 14 + shuffle: true + drop_last: true + worker_num: 2 + use_process: true + drop_empty: false + mixup_epoch: 200 + +EvalReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + use_process: true + batch_size: 1 + drop_empty: false + worker_num: 2 + +TestReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - 
!CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + batch_size: 1 diff --git a/configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml b/configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml new file mode 100644 index 000000000..918b49b41 --- /dev/null +++ b/configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml @@ -0,0 +1,157 @@ +architecture: CornerNetSqueeze +use_gpu: true +max_iters: 500000 +log_smooth_window: 20 +log_iter: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar +weights: output/cornernet_squeeze_r50_vd_fpn/model_final +num_classes: 80 +stack: 1 + +CornerNetSqueeze: + backbone: ResNet + fpn: FPN + corner_head: CornerHead + +ResNet: + norm_type: affine_channel + depth: 50 + feature_maps: [3, 4, 5] + freeze_at: 2 + variant: d + +FPN: + min_level: 3 + max_level: 6 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125] + +CornerHead: + train_batch_size: 14 + test_batch_size: 1 + ae_threshold: 0.5 + num_dets: 100 + top_k: 20 + +PostProcess: + use_soft_nms: true + detections_per_im: 100 + nms_thresh: 0.001 + sigma: 0.5 + +LearningRate: + base_lr: 0.0005 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 450000 + +OptimizerBuilder: + optimizer: + type: Adam + regularizer: NULL + +TrainReader: + inputs_def: + image_shape: [3, 511, 511] + fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks'] + output_size: 64 + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: False + - !CornerCrop + input_size: 511 + - !Resize + target_dim: 511 + - !RandomFlipImage + prob: 0.5 + - !CornerRandColor + saturation: 0.4 + contrast: 0.4 + brightness: 0.4 + - !Lighting + eigval: [0.2141788, 0.01817699, 0.00341571] + eigvec: [[-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938]] + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: False + is_channel_first: False + - !Permute + to_bgr: False + - !CornerTarget + output_size: [64, 64] + num_classes: 80 + batch_size: 14 + shuffle: true + drop_last: true + worker_num: 2 + use_process: true + drop_empty: false + +EvalReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + use_process: true + batch_size: 1 + drop_empty: false + worker_num: 2 + +TestReader: + inputs_def: + fields: ['image', 'im_id', 'ratios', 'borders'] + output_size: 64 + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: 
false + - !CornerCrop + is_train: false + - !CornerRatio + input_size: 511 + output_size: 64 + - !Permute + to_bgr: False + - !NormalizeImage + mean: [0.40789654, 0.44719302, 0.47026115] + std: [0.28863828, 0.27408164, 0.2780983] + is_scale: True + is_channel_first: True + batch_size: 1 diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md index c20b090ee..84c8843be 100644 --- a/docs/MODEL_ZOO.md +++ b/docs/MODEL_ZOO.md @@ -214,3 +214,7 @@ Please refer [face detection models](https://github.com/PaddlePaddle/PaddleDetec ### Object Detection in Open Images Dataset V5 Please refer [Open Images Dataset V5 Baseline model](featured_model/OIDV5_BASELINE_MODEL.md) for details. + +### Anchor Free Models + +Please refer [Anchor Free Models](featured_model/ANCHOR_FREE_DETECTION.md) for details. diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md index df7ccf9e5..19ade0e6d 100644 --- a/docs/MODEL_ZOO_cn.md +++ b/docs/MODEL_ZOO_cn.md @@ -204,3 +204,7 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型 ### 基于Open Images V5数据集的物体检测 详细请参考[Open Images V5数据集基线模型](featured_model/OIDV5_BASELINE_MODEL.md)。 + +### Anchor Free系列模型 + +详细请参考[Anchor Free系列模型](featured_model/ANCHOR_FREE_DETECTION.md)。 diff --git a/docs/featured_model/ANCHOR_FREE_DETECTION.md b/docs/featured_model/ANCHOR_FREE_DETECTION.md new file mode 100644 index 000000000..2563896f2 --- /dev/null +++ b/docs/featured_model/ANCHOR_FREE_DETECTION.md @@ -0,0 +1 @@ +**文档教程请参考:** [ACHOR\_FREE\_DETECTION.md](../../configs/anchor_free/README.md)
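
The learning-rate schedule in `cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml` combines `LinearWarmup` (4000 steps, start factor 0) with `CosineDecay` over the full 500000 iterations at a base LR of 0.005. The sketch below traces that curve, assuming the conventional cosine-annealing formula; it is an approximation for illustration, not the PaddleDetection scheduler itself.

```python
import math

def mixup_cosine_lr(step, base_lr=0.005, max_iters=500000,
                    warmup_steps=4000, warmup_start_factor=0.0):
    """Approximate LR at a given iteration for the mixup_cosine config."""
    if step < warmup_steps:
        # linear warmup from start_factor * base_lr up to base_lr
        alpha = step / warmup_steps
        return base_lr * (warmup_start_factor + (1.0 - warmup_start_factor) * alpha)
    # cosine decay from base_lr down to 0 over max_iters
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * step / max_iters))

# e.g. mixup_cosine_lr(0) == 0.0, mixup_cosine_lr(4000) is close to 0.005,
# mixup_cosine_lr(250000) is about 0.0025, mixup_cosine_lr(500000) is about 0.0
```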
diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py index 994f70d70..d41efd934 100644 --- a/ppdet/data/transform/op_helper.py +++ b/ppdet/data/transform/op_helper.py @@ -393,3 +393,52 @@ def is_poly(segm): assert isinstance(segm, (list, dict)), \ "Invalid segm type: {}".format(type(segm)) return isinstance(segm, list) + + +def gaussian_radius(bbox_size, min_overlap): + height, width = bbox_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = np.sqrt(b1**2 - 4 * a1 * c1) + radius1 = (b1 - sq1) / (2 * a1) + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = np.sqrt(b2**2 - 4 * a2 * c2) + radius2 = (b2 - sq2) / (2 * a2) + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = np.sqrt(b3**2 - 4 * a3 * c3) + radius3 = (b3 + sq3) / (2 * a3) + return min(radius1, radius2, radius3) + + +def draw_gaussian(heatmap, center, radius, k=1, delte=6): + diameter = 2 * radius + 1 + gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte) + + x, y = center + + height, width = heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[radius - top:radius + bottom, radius - left: + radius + right] + np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + + +def gaussian2D(shape, sigma=1): + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m + 1, -n:n + 1] + + h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 8b3a47999..750e9b82e 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -42,7 +42,7 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process, generate_sample_bbox, clip_bbox, data_anchor_sampling, satisfy_sample_constraint_coverage, crop_image_sampling, generate_sample_bbox_square, bbox_area_sampling, - is_poly) + is_poly, gaussian_radius, draw_gaussian) logger = logging.getLogger(__name__) @@ -1243,10 +1243,13 @@ class ColorDistort(BaseOperator): def __call__(self, sample, context=None): img = sample['image'] if self.random_apply: - distortions = np.random.permutation([ - self.apply_brightness, self.apply_contrast, - self.apply_saturation, self.apply_hue - ]) + functions = [ + self.apply_brightness, + self.apply_contrast, + self.apply_saturation, + self.apply_hue, + ] + distortions = np.random.permutation(functions) for func in distortions: img = func(img) sample['image'] = img @@ -1266,6 +1269,66 @@ class ColorDistort(BaseOperator): return sample +@register_op +class CornerRandColor(ColorDistort): + """Random color for CornerNet series models. + Args: + saturation (float): saturation settings. + contrast (float): contrast settings. + brightness (float): brightness settings. + is_scale (bool): whether to scale the input image. + """ + + def __init__(self, + saturation=0.4, + contrast=0.4, + brightness=0.4, + is_scale=True): + super(CornerRandColor, self).__init__( + saturation=saturation, contrast=contrast, brightness=brightness) + self.is_scale = is_scale + + def apply_saturation(self, img, img_gray): + alpha = 1. 
+ np.random.uniform( + low=-self.saturation, high=self.saturation) + self._blend(alpha, img, img_gray[:, :, None]) + return img + + def apply_contrast(self, img, img_gray): + alpha = 1. + np.random.uniform(low=-self.contrast, high=self.contrast) + img_mean = img_gray.mean() + self._blend(alpha, img, img_mean) + return img + + def apply_brightness(self, img, img_gray): + alpha = 1 + np.random.uniform( + low=-self.brightness, high=self.brightness) + img *= alpha + return img + + def _blend(self, alpha, img, img_mean): + img *= alpha + img_mean *= (1 - alpha) + img += img_mean + + def __call__(self, sample, context=None): + img = sample['image'] + if self.is_scale: + img = img.astype(np.float32, copy=False) + img /= 255. + img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + functions = [ + self.apply_brightness, + self.apply_contrast, + self.apply_saturation, + ] + distortions = np.random.permutation(functions) + for func in distortions: + img = func(img, img_gray) + sample['image'] = img + return sample + + @register_op class NormalizePermute(BaseOperator): """Normalize and permute channel order. @@ -1672,3 +1735,239 @@ class BboxXYXY2XYWH(BaseOperator): bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2. sample['gt_bbox'] = bbox return sample + + +@register_op +class Lighting(BaseOperator): + """ + Lighting the imagen by eigenvalues and eigenvectors + Args: + eigval (list): eigenvalues + eigvec (list): eigenvectors + alphastd (float): random weight of lighting, 0.1 by default + """ + + def __init__(self, eigval, eigvec, alphastd=0.1): + super(Lighting, self).__init__() + self.alphastd = alphastd + self.eigval = np.array(eigval).astype('float32') + self.eigvec = np.array(eigvec).astype('float32') + + def __call__(self, sample, context=None): + alpha = np.random.normal(scale=self.alphastd, size=(3, )) + sample['image'] += np.dot(self.eigvec, self.eigval * alpha) + return sample + + +@register_op +class CornerTarget(BaseOperator): + """ + Generate targets for CornerNet by ground truth data. + Args: + output_size (int): the size of output heatmaps. + num_classes (int): num of classes. + gaussian_bump (bool): whether to apply gaussian bump on gt targets. + True by default. + gaussian_rad (int): radius of gaussian bump. If it is set to -1, the + radius will be calculated by iou. -1 by default. + gaussian_iou (float): the threshold iou of predicted bbox to gt bbox. + If the iou is larger than threshold, the predicted bboox seems as + positive sample. 0.3 by default + max_tag_len (int): max num of gt box per image. 
+ """ + + def __init__(self, + output_size, + num_classes, + gaussian_bump=True, + gaussian_rad=-1, + gaussian_iou=0.3, + max_tag_len=128): + super(CornerTarget, self).__init__() + self.num_classes = num_classes + self.output_size = output_size + self.gaussian_bump = gaussian_bump + self.gaussian_rad = gaussian_rad + self.gaussian_iou = gaussian_iou + self.max_tag_len = max_tag_len + + def __call__(self, sample, context=None): + tl_heatmaps = np.zeros( + (self.num_classes, self.output_size[0], self.output_size[1]), + dtype=np.float32) + br_heatmaps = np.zeros( + (self.num_classes, self.output_size[0], self.output_size[1]), + dtype=np.float32) + + tl_regrs = np.zeros((self.max_tag_len, 2), dtype=np.float32) + br_regrs = np.zeros((self.max_tag_len, 2), dtype=np.float32) + tl_tags = np.zeros((self.max_tag_len), dtype=np.int64) + br_tags = np.zeros((self.max_tag_len), dtype=np.int64) + tag_masks = np.zeros((self.max_tag_len), dtype=np.uint8) + tag_lens = np.zeros((), dtype=np.int32) + tag_nums = np.zeros((1), dtype=np.int32) + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + keep_inds = ((gt_bbox[:, 2] - gt_bbox[:, 0]) > 0) & \ + ((gt_bbox[:, 3] - gt_bbox[:, 1]) > 0) + gt_bbox = gt_bbox[keep_inds] + gt_class = gt_class[keep_inds] + sample['gt_bbox'] = gt_bbox + sample['gt_class'] = gt_class + width_ratio = self.output_size[1] / sample['w'] + height_ratio = self.output_size[0] / sample['h'] + for i in range(gt_bbox.shape[0]): + width = gt_bbox[i][2] - gt_bbox[i][0] + height = gt_bbox[i][3] - gt_bbox[i][1] + + xtl, ytl = gt_bbox[i][0], gt_bbox[i][1] + xbr, ybr = gt_bbox[i][2], gt_bbox[i][3] + + fxtl = (xtl * width_ratio) + fytl = (ytl * height_ratio) + fxbr = (xbr * width_ratio) + fybr = (ybr * height_ratio) + + xtl = int(fxtl) + ytl = int(fytl) + xbr = int(fxbr) + ybr = int(fybr) + if self.gaussian_bump: + width = math.ceil(width * width_ratio) + height = math.ceil(height * height_ratio) + if self.gaussian_rad == -1: + radius = gaussian_radius((height, width), self.gaussian_iou) + radius = max(0, int(radius)) + else: + radius = self.gaussian_rad + draw_gaussian(tl_heatmaps[gt_class[i][0]], [xtl, ytl], radius) + draw_gaussian(br_heatmaps[gt_class[i][0]], [xbr, ybr], radius) + else: + tl_heatmaps[gt_class[i][0], ytl, xtl] = 1 + br_heatmaps[gt_class[i][0], ybr, xbr] = 1 + + tl_regrs[i, :] = [fxtl - xtl, fytl - ytl] + br_regrs[i, :] = [fxbr - xbr, fybr - ybr] + tl_tags[tag_lens] = ytl * self.output_size[1] + xtl + br_tags[tag_lens] = ybr * self.output_size[1] + xbr + tag_lens += 1 + + tag_masks[:tag_lens] = 1 + + sample['tl_heatmaps'] = tl_heatmaps + sample['br_heatmaps'] = br_heatmaps + sample['tl_regrs'] = tl_regrs + sample['br_regrs'] = br_regrs + sample['tl_tags'] = tl_tags + sample['br_tags'] = br_tags + sample['tag_masks'] = tag_masks + + return sample + + +@register_op +class CornerCrop(BaseOperator): + """ + Random crop for CornerNet + Args: + random_scales (list): scales of output_size to input_size. 
+ border (int): border of corp center + is_train (bool): train or test + input_size (int): size of input image + """ + + def __init__(self, + random_scales=[0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3], + border=128, + is_train=True, + input_size=511): + super(CornerCrop, self).__init__() + self.random_scales = random_scales + self.border = border + self.is_train = is_train + self.input_size = input_size + + def __call__(self, sample, context=None): + im_h, im_w = int(sample['h']), int(sample['w']) + if self.is_train: + scale = np.random.choice(self.random_scales) + height = int(self.input_size * scale) + width = int(self.input_size * scale) + + w_border = self._get_border(self.border, im_w) + h_border = self._get_border(self.border, im_h) + + ctx = np.random.randint(low=w_border, high=im_w - w_border) + cty = np.random.randint(low=h_border, high=im_h - h_border) + + else: + cty, ctx = im_h // 2, im_w // 2 + height = im_h | 127 + width = im_w | 127 + + cropped_image = np.zeros( + (height, width, 3), dtype=sample['image'].dtype) + + x0, x1 = max(ctx - width // 2, 0), min(ctx + width // 2, im_w) + y0, y1 = max(cty - height // 2, 0), min(cty + height // 2, im_h) + + left_w, right_w = ctx - x0, x1 - ctx + top_h, bottom_h = cty - y0, y1 - cty + + # crop image + cropped_ctx, cropped_cty = width // 2, height // 2 + x_slice = slice(int(cropped_ctx - left_w), int(cropped_ctx + right_w)) + y_slice = slice(int(cropped_cty - top_h), int(cropped_cty + bottom_h)) + cropped_image[y_slice, x_slice, :] = sample['image'][y0:y1, x0:x1, :] + + sample['image'] = cropped_image + sample['h'], sample['w'] = height, width + + if self.is_train: + # crop detections + gt_bbox = sample['gt_bbox'] + gt_bbox[:, 0:4:2] -= x0 + gt_bbox[:, 1:4:2] -= y0 + gt_bbox[:, 0:4:2] += cropped_ctx - left_w + gt_bbox[:, 1:4:2] += cropped_cty - top_h + else: + sample['borders'] = np.array( + [ + cropped_cty - top_h, cropped_cty + bottom_h, + cropped_ctx - left_w, cropped_ctx + right_w + ], + dtype=np.float32) + + return sample + + def _get_border(self, border, size): + i = 1 + while size - border // i <= border // i: + i *= 2 + return border // i + + +@register_op +class CornerRatio(BaseOperator): + """ + Ratio of output size to image size + Args: + input_size (int): the size of input size + output_size (int): the size of heatmap + """ + + def __init__(self, input_size=511, output_size=64): + super(CornerRatio, self).__init__() + self.input_size = input_size + self.output_size = output_size + + def __call__(self, sample, context=None): + scale = (self.input_size + 1) // self.output_size + out_height, out_width = (sample['h'] + 1) // scale, ( + sample['w'] + 1) // scale + height_ratio = out_height / float(sample['h']) + width_ratio = out_width / float(sample['w']) + sample['ratios'] = np.array([height_ratio, width_ratio]) + + return sample diff --git a/ppdet/ext_op/README.md b/ppdet/ext_op/README.md new file mode 100644 index 000000000..cc6173b36 --- /dev/null +++ b/ppdet/ext_op/README.md @@ -0,0 +1,63 @@ +# 自定义OP的编译过程 + +**注意:** 编译自定义OP使用的gcc版本须与Paddle编译使用gcc版本一致,Paddle develop每日版本目前采用**gcc 4.8.2**版本编译,若使用每日版本,请使用**gcc 4.8.2**版本编译自定义OP,否则可能出现兼容性问题。 + +## 代码结构 + + - src: 扩展OP C++/CUDA 源码 + - cornerpool_lib.py: Python API封装 + - tests: 各OP单测程序 + + +## 编译自定义OP + +自定义op需要将实现的C++、CUDA代码编译成动态库,```src/mask.sh```中通过g++/nvcc编译,当然您也可以写Makefile或者CMake。 + +编译需要include PaddlePaddle的相关头文件,链接PaddlePaddle的lib库。 头文件和lib库可通过下面命令获取到: + +``` +# python +>>> import paddle +>>> print(paddle.sysconfig.get_include()) 
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/include +>>> print(paddle.sysconfig.get_lib()) +/paddle/pyenv/local/lib/python2.7/site-packages/paddle/libs +``` + +我们提供动态库编译脚本如下: + +``` +cd src +sh make.sh +``` + +最终编译会产出`cornerpool_lib.so` + +**说明:** 若使用源码编译安装PaddlePaddle的方式,编译过程中`cmake`未设置`WITH_MKLDNN`的方式, +编译自定义OP时会报错找不到`mkldnn.h`等文件,可在`make.sh`中删除编译命令中的`-DPADDLE_WITH_MKLDNN`选项。 + + +## 执行单测 + +执行下列单测,确保自定义算子可在网络中正确使用: + +``` +# 回到 ext_op 目录,添加 PYTHONPATH +cd .. +export PYTHONPATH=$PYTHONPATH:`pwd` + +# 运行单测 +python test/test_corner_op.py +``` + +单测运行成功会输出提示信息,如下所示: + +``` +. +---------------------------------------------------------------------- +Ran 4 test in 2.858s + +OK +``` + +更多关于如何在框架外部自定义 C++ OP,可阅读[官网说明文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/index_cn.html) diff --git a/ppdet/ext_op/__init__.py b/ppdet/ext_op/__init__.py new file mode 100644 index 000000000..5d38f757f --- /dev/null +++ b/ppdet/ext_op/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from . import cornerpool_lib +from .cornerpool_lib import * + +__all__ = cornerpool_lib.__all__ diff --git a/ppdet/ext_op/cornerpool_lib.py b/ppdet/ext_op/cornerpool_lib.py new file mode 100644 index 000000000..c56fc661a --- /dev/null +++ b/ppdet/ext_op/cornerpool_lib.py @@ -0,0 +1,189 @@ +import os +import paddle.fluid as fluid + +file_dir = os.path.dirname(os.path.abspath(__file__)) +fluid.load_op_library(os.path.join(file_dir, 'src/cornerpool_lib.so')) + +from paddle.fluid.layer_helper import LayerHelper + +__all__ = [ + 'bottom_pool', + 'top_pool', + 'right_pool', + 'left_pool', +] + + +def bottom_pool(input, is_test=False, name=None): + """ + This layer calculates the bottom pooling output based on the input. + Scan the input from top to bottm for the vertical max-pooling. + The output has the same shape with input. + Args: + input(Variable): This input is a Tensor with shape [N, C, H, W]. + The data type is float32 or float64. + Returns: + Variable(Tensor): The output of bottom_pool, with shape [N, C, H, W]. + The data type is float32 or float64. 
+ Examples: + ..code-block:: python + import paddle.fluid as fluid + import cornerpool_lib + input = fluid.data( + name='input', shape=[2, 64, 10, 10], dtype='float32') + output = corner_pool.bottom_pool(input) + """ + if is_test: + helper = LayerHelper('bottom_pool', **locals()) + dtype = helper.input_dtype() + output = helper.create_variable_for_type_inference(dtype) + max_map = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="bottom_pool", + inputs={"X": input}, + outputs={"Output": output, + "MaxMap": max_map}) + return output + H = input.shape[2] + i = 1 + output = input + while i < H: + cur = output[:, :, i:, :] + next = output[:, :, :H - i, :] + max_v = fluid.layers.elementwise_max(cur, next) + output = fluid.layers.concat([output[:, :, :i, :], max_v], axis=2) + i *= 2 + + return output + + +def top_pool(input, is_test=False, name=None): + """ + This layer calculates the top pooling output based on the input. + Scan the input from bottom to top for the vertical max-pooling. + The output has the same shape with input. + Args: + input(Variable): This input is a Tensor with shape [N, C, H, W]. + The data type is float32 or float64. + Returns: + Variable(Tensor): The output of top_pool, with shape [N, C, H, W]. + The data type is float32 or float64. + Examples: + ..code-block:: python + import paddle.fluid as fluid + import cornerpool_lib + input = fluid.data( + name='input', shape=[2, 64, 10, 10], dtype='float32') + output = corner_pool.top_pool(input) + """ + if is_test: + helper = LayerHelper('top_pool', **locals()) + dtype = helper.input_dtype() + output = helper.create_variable_for_type_inference(dtype) + max_map = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="top_pool", + inputs={"X": input}, + outputs={"Output": output, + "MaxMap": max_map}) + return output + + H = input.shape[2] + i = 1 + output = input + while i < H: + cur = output[:, :, :H - i, :] + next = output[:, :, i:, :] + max_v = fluid.layers.elementwise_max(cur, next) + output = fluid.layers.concat([max_v, output[:, :, H - i:, :]], axis=2) + i *= 2 + + return output + + +def right_pool(input, is_test=False, name=None): + """ + This layer calculates the right pooling output based on the input. + Scan the input from left to right for the horizontal max-pooling. + The output has the same shape with input. + Args: + input(Variable): This input is a Tensor with shape [N, C, H, W]. + The data type is float32 or float64. + Returns: + Variable(Tensor): The output of right_pool, with shape [N, C, H, W]. + The data type is float32 or float64. 
+ Examples: + ..code-block:: python + import paddle.fluid as fluid + import cornerpool_lib + input = fluid.data( + name='input', shape=[2, 64, 10, 10], dtype='float32') + output = corner_pool.right_pool(input) + """ + if is_test: + helper = LayerHelper('right_pool', **locals()) + dtype = helper.input_dtype() + output = helper.create_variable_for_type_inference(dtype) + max_map = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="right_pool", + inputs={"X": input}, + outputs={"Output": output, + "MaxMap": max_map}) + return output + + W = input.shape[3] + i = 1 + output = input + while i < W: + cur = output[:, :, :, i:] + next = output[:, :, :, :W - i] + max_v = fluid.layers.elementwise_max(cur, next) + output = fluid.layers.concat([output[:, :, :, :i], max_v], axis=-1) + i *= 2 + + return output + + +def left_pool(input, is_test=False, name=None): + """ + This layer calculates the left pooling output based on the input. + Scan the input from right to left for the horizontal max-pooling. + The output has the same shape with input. + Args: + input(Variable): This input is a Tensor with shape [N, C, H, W]. + The data type is float32 or float64. + Returns: + Variable(Tensor): The output of left_pool, with shape [N, C, H, W]. + The data type is float32 or float64. + Examples: + ..code-block:: python + import paddle.fluid as fluid + import cornerpool_lib + input = fluid.data( + name='input', shape=[2, 64, 10, 10], dtype='float32') + output = corner_pool.left_pool(input) + """ + if is_test: + helper = LayerHelper('left_pool', **locals()) + dtype = helper.input_dtype() + output = helper.create_variable_for_type_inference(dtype) + max_map = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="left_pool", + inputs={"X": input}, + outputs={"Output": output, + "MaxMap": max_map}) + return output + + W = input.shape[3] + i = 1 + output = input + while i < W: + cur = output[:, :, :, :W - i] + next = output[:, :, :, i:] + max_v = fluid.layers.elementwise_max(cur, next) + output = fluid.layers.concat([max_v, output[:, :, :, W - i:]], axis=-1) + i *= 2 + + return output diff --git a/ppdet/ext_op/src/bottom_pool_op.cc b/ppdet/ext_op/src/bottom_pool_op.cc new file mode 100644 index 000000000..6a867d1f1 --- /dev/null +++ b/ppdet/ext_op/src/bottom_pool_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class BottomPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class BottomPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the bottom pooling output based on the input. +Scan the input from top to bottom for the vertical max-pooling. +The output has the same shape with input. + )Doc"); + } +}; + +class BottomPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class BottomPoolGradDescMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr op) const override { + op->SetType("bottom_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(bottom_pool, + ops::BottomPoolOp, + ops::BottomPoolOpMaker, + ops::BottomPoolGradDescMaker, + ops::BottomPoolGradDescMaker); +REGISTER_OPERATOR(bottom_pool_grad, ops::BottomPoolOpGrad); diff --git a/ppdet/ext_op/src/bottom_pool_op.cu b/ppdet/ext_op/src/bottom_pool_op.cu new file mode 100644 index 000000000..4912ec3c0 --- /dev/null +++ b/ppdet/ext_op/src/bottom_pool_op.cu @@ -0,0 +1,104 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class BottomPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / height); + + auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T)); + T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 2, false, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 2, output_data); + } +}; + +template +class BottomPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 2, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(bottom_pool, + ops::BottomPoolOpCUDAKernel, + ops::BottomPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(bottom_pool_grad, + ops::BottomPoolGradOpCUDAKernel, + ops::BottomPoolGradOpCUDAKernel); diff --git a/ppdet/ext_op/src/left_pool_op.cc b/ppdet/ext_op/src/left_pool_op.cc new file mode 100644 index 000000000..c2a8f169f --- /dev/null +++ b/ppdet/ext_op/src/left_pool_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class LeftPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class LeftPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the left pooling output based on the input. +Scan the input from right to left for the horizontal max-pooling. +The output has the same shape with input. + )Doc"); + } +}; + +class LeftPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class LeftPoolGradDescMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr op) const override { + op->SetType("left_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(left_pool, + ops::LeftPoolOp, + ops::LeftPoolOpMaker, + ops::LeftPoolGradDescMaker, + ops::LeftPoolGradDescMaker); +REGISTER_OPERATOR(left_pool_grad, ops::LeftPoolOpGrad); diff --git a/ppdet/ext_op/src/left_pool_op.cu b/ppdet/ext_op/src/left_pool_op.cu new file mode 100644 index 000000000..a5e9323ad --- /dev/null +++ b/ppdet/ext_op/src/left_pool_op.cu @@ -0,0 +1,106 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class LeftPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / width); + + auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T)); + T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 3, true, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 3, output_data); + + } +}; + +template +class LeftPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 3, in_grad_data); + } +}; + + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(left_pool, + ops::LeftPoolOpCUDAKernel, + ops::LeftPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(left_pool_grad, + ops::LeftPoolGradOpCUDAKernel, + ops::LeftPoolGradOpCUDAKernel); diff 
--git a/ppdet/ext_op/src/make.sh b/ppdet/ext_op/src/make.sh new file mode 100755 index 000000000..bd0d3a3c9 --- /dev/null +++ b/ppdet/ext_op/src/make.sh @@ -0,0 +1,23 @@ +include_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_include())' ) +lib_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_lib())' ) + +echo $include_dir +echo $lib_dir + +OPS='bottom_pool_op top_pool_op right_pool_op left_pool_op' +for op in ${OPS} +do +nvcc ${op}.cu -c -o ${op}.cu.o -ccbin cc -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -DPADDLE_WITH_MKLDNN -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O0 -g -DNVCC \ + -I ${include_dir}/third_party/ \ + -I ${include_dir} +done + +g++ bottom_pool_op.cc bottom_pool_op.cu.o top_pool_op.cc top_pool_op.cu.o right_pool_op.cc right_pool_op.cu.o left_pool_op.cc left_pool_op.cu.o -o cornerpool_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \ + -I ${include_dir}/third_party/ \ + -I ${include_dir} \ + -L ${lib_dir} \ + -L /usr/local/cuda/lib64 -lpaddle_framework -lcudart + +rm *.cu.o + +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir diff --git a/ppdet/ext_op/src/right_pool_op.cc b/ppdet/ext_op/src/right_pool_op.cc new file mode 100644 index 000000000..6bf74a1b0 --- /dev/null +++ b/ppdet/ext_op/src/right_pool_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class RightPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class RightPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the right pooling output based on the input. +Scan the input from left to right or the horizontal max-pooling. +The output has the same shape with input. 
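+That is, each output location holds the running maximum of its row scanned from left to right; for example, a row [1, 3, 2, 4] becomes [1, 3, 3, 4], matching the right_pool_np reference implementation in test/test_corner_pool.py.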
+ )Doc"); + } +}; + +class RightPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class RightPoolGradDescMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr op) const override { + op->SetType("right_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(right_pool, + ops::RightPoolOp, + ops::RightPoolOpMaker, + ops::RightPoolGradDescMaker, + ops::RightPoolGradDescMaker); +REGISTER_OPERATOR(right_pool_grad, ops::RightPoolOpGrad); diff --git a/ppdet/ext_op/src/right_pool_op.cu b/ppdet/ext_op/src/right_pool_op.cu new file mode 100644 index 000000000..08a52ecf1 --- /dev/null +++ b/ppdet/ext_op/src/right_pool_op.cu @@ -0,0 +1,105 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class RightPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / width); + + auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T)); + T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 3, false, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 3, output_data); + + } +}; + +template +class RightPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 3, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(right_pool, + ops::RightPoolOpCUDAKernel, + ops::RightPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(right_pool_grad, + ops::RightPoolGradOpCUDAKernel, + ops::RightPoolGradOpCUDAKernel); diff --git a/ppdet/ext_op/src/top_pool_op.cc b/ppdet/ext_op/src/top_pool_op.cc new file mode 100755 index 000000000..29cba6660 --- /dev/null +++ b/ppdet/ext_op/src/top_pool_op.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class TopPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + ctx->ShareDim("X", /*->*/ "MaxMap"); + ctx->ShareDim("X", /*->*/ "Output"); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class TopPoolOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", + "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddOutput("Output", "Output with same shape as input(X)"); + AddComment( + R"Doc( +This operatio calculates the top pooling output based on the input. +Scan the input from bottom to top for the vertical max-pooling. +The output has the same shape with input. + )Doc"); + } +}; + +class TopPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + +protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "Input(Output@GRAD) should not be null"); + + auto out_grad_name = framework::GradVarName("Output"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Output"))->type(), + ctx.GetPlace()); + } +}; + +template +class TopPoolGradDescMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("top_pool_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + op->SetInput("MaxMap", this->Output("MaxMap")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(top_pool, + ops::TopPoolOp, + ops::TopPoolOpMaker, + ops::TopPoolGradDescMaker, + ops::TopPoolGradDescMaker); +REGISTER_OPERATOR(top_pool_grad, ops::TopPoolOpGrad); diff --git a/ppdet/ext_op/src/top_pool_op.cu b/ppdet/ext_op/src/top_pool_op.cu new file mode 100755 index 000000000..f6237fe79 --- /dev/null +++ b/ppdet/ext_op/src/top_pool_op.cu @@ -0,0 +1,104 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +GUnless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include +#include "util.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +class TopPoolOpCUDAKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *max_map = ctx.Output("MaxMap"); + auto *output = ctx.Output("Output"); + auto *x_data = x->data(); + auto x_dims = x->dims(); + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int num = x->numel(); + auto& dev_ctx = ctx.cuda_device_context(); + + int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int blocks = NumBlocks(num / height); + + auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T)); + T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); + auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int)); + int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); + + GetMaxInfo<<>>(x->data(), NC_num, height, width, 2, true, max_val_data, max_ind_data, max_map_data); + + blocks = NumBlocks(num); + ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 2, output_data); + } +}; + +template +class TopPoolGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* max_map = ctx.Input("MaxMap"); + auto* out_grad = ctx.Input(framework::GradVarName("Output")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto x_dims = x->dims(); + auto& dev_ctx = ctx.cuda_device_context(); + T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); + auto gpu_place = boost::get(dev_ctx.GetPlace()); + + int threads = kNumCUDAThreads; + int NC_num = x_dims[0] * x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int grad_num = in_grad->numel(); + int blocks = NumBlocks(grad_num); + FillConstant<<>>(in_grad_data, 0, grad_num); + + ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 2, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(top_pool, + ops::TopPoolOpCUDAKernel, + ops::TopPoolOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(top_pool_grad, + ops::TopPoolGradOpCUDAKernel, + ops::TopPoolGradOpCUDAKernel); diff --git 
a/ppdet/ext_op/src/util.cu.h b/ppdet/ext_op/src/util.cu.h new file mode 100644 index 000000000..615e45a78 --- /dev/null +++ b/ppdet/ext_op/src/util.cu.h @@ -0,0 +1,223 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/memory/memory.h" +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void FillConstant(T* x, int num, int fill_num) { + CUDA_1D_KERNEL_LOOP(i, fill_num) { + x[i] = static_cast(num); + } +} + +template +__global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int W, + const int axis, const int start, const int end, + T* output) { + int HW_num = H * W; + int length = axis == 2 ? W : H; + int sliced_len = end - start; + int cur_HW_num = length * sliced_len; + // slice input on H or W (axis is 2 or 3) + CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) { + int NC_id = i / cur_HW_num; + int HW_id = i % cur_HW_num; + if (axis == 2){ + output[i] = x[NC_id * HW_num + start * W + HW_id]; + } else if (axis == 3) { + int col = HW_id % sliced_len; + int row = HW_id / sliced_len; + output[i] = x[NC_id * HW_num + row * W + start + col]; + } + } +} + +template +__global__ void MaxOut(const T* input, const int next_ind, const int NC_num, + const int H, const int W, const int axis, + const int start, const int end, T* output) { + int HW_num = H * W; + int length = axis == 2 ? W : H; + T cur = static_cast(0.); + T next = static_cast(0.); + T max_v = static_cast(0.); + int sliced_len = end - start; + int cur_HW_num = length * sliced_len; + // compare cur and next and assign max values to output + CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) { + int NC_id = i / cur_HW_num; + int HW_id = i % cur_HW_num; + + if (axis == 2){ + cur = input[NC_id * HW_num + start * W + HW_id]; + next = input[NC_id * HW_num + next_ind * W + HW_id]; + max_v = cur > next ? cur : next; + output[NC_id * HW_num + start * W + HW_id] = max_v; + } else if (axis == 3) { + int col = HW_id % sliced_len; + int row = HW_id / sliced_len; + cur = input[NC_id * HW_num + row * W + start + col]; + next = input[NC_id * HW_num + row * W + next_ind + col]; + max_v = cur > next ? cur : next; + output[NC_id * HW_num + row * W + start + col] = max_v; + } + __syncthreads(); + } +} + +template +__global__ void UpdateMaxInfo(const T* input, const int NC_num, + const int H, const int W, const int axis, + const int index, T* max_val, int* max_ind) { + int length = axis == 2 ? 
W : H; + int HW_num = H * W; + T val = static_cast(0.); + CUDA_1D_KERNEL_LOOP(i, NC_num * length) { + int NC_id = i / length; + int length_id = i % length; + if (axis == 2) { + val = input[NC_id * HW_num + index * W + length_id]; + } else if (axis == 3) { + val = input[NC_id * HW_num + length_id * W + index]; + } + if (val > max_val[i]) { + max_val[i] = val; + max_ind[i] = index; + } + __syncthreads(); + } +} + +template +__global__ void ScatterAddOnAxis(const T* input, const int start, const int* max_ind, const int NC_num, const int H, const int W, const int axis, T* output) { + int length = axis == 2 ? W : H; + int HW_num = H * W; + CUDA_1D_KERNEL_LOOP(i, NC_num * length) { + int NC_id = i / length; + int length_id = i % length; + int id_ = max_ind[i]; + if (axis == 2) { + platform::CudaAtomicAdd(output + NC_id * HW_num + id_ * W + length_id, input[NC_id * HW_num + start * W + length_id]); + //output[NC_id * HW_num + id_ * W + length_id] += input[NC_id * HW_num + start * W + length_id]; + } else if (axis == 3) { + platform::CudaAtomicAdd(output + NC_id * HW_num + length_id * W + id_, input[NC_id * HW_num + length_id * W + start]); + //output[NC_id * HW_num + length_id * W + id_] += input[NC_id * HW_num + length_id * W + start]; + } + __syncthreads(); + } +} + +template +__global__ void GetMaxInfo(const T* input, const int NC_num, + const int H, const int W, const int axis, + const bool reverse, T* max_val, int* max_ind, + int* max_map) { + int start = 0; + int end = axis == 2 ? H: W; + int s = reverse ? end-1 : start; + int e = reverse ? start-1 : end; + int step = reverse ? -1 : 1; + int len = axis == 2 ? W : H; + int loc = 0; + T val = static_cast(0.); + for (int i = s; ; ) { + if (i == s) { + CUDA_1D_KERNEL_LOOP(j, NC_num * len) { + int NC_id = j / len; + int len_id = j % len; + if (axis == 2) { + loc = NC_id * H * W + i * W + len_id; + } else if (axis == 3){ + loc = NC_id * H * W + len_id * W + i; + } + max_ind[j] = i; + max_map[loc] = max_ind[j]; + max_val[j] = input[loc]; + __syncthreads(); + } + } else { + CUDA_1D_KERNEL_LOOP(j, NC_num * len) { + int NC_id = j / len; + int len_id = j % len; + + if (axis == 2) { + loc = NC_id * H * W + i * W + len_id; + } else if (axis == 3){ + loc = NC_id * H * W + len_id * W + i; + } + val = input[loc]; + T max_v = max_val[j]; + if (val > max_v) { + max_val[j] = val; + max_map[loc] = i; + max_ind[j] = i; + } else { + max_map[loc] = max_ind[j]; + } + __syncthreads(); + } + } + i += step; + if (s < e && i >= e) break; + if (s > e && i <= e) break; + } +} + +template +__global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){ + CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) { + int loc = max_map[i]; + int NC_id = i / (H * W); + int len_id = 0; + if (axis == 2) { + len_id = i % W; + output[i] = input[NC_id * H * W + loc * W + len_id]; + } else { + len_id = i % (H * W) / W; + output[i] = input[NC_id * H * W + len_id * W + loc]; + } + } +} + +template +__global__ void ScatterAddBw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){ + CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) { + int loc = max_map[i]; + int NC_id = i / (H * W); + int len_id = 0; + int offset = 0; + if (axis == 2) { + len_id = i % W; + offset = NC_id * H * W + loc * W + len_id; + } else { + len_id = i % (H * W) / W; + offset = NC_id * H * W + len_id * W + loc; + } + platform::CudaAtomicAdd(output + offset, input[i]); + } +} + +} // namespace operators +} // namespace 
paddle diff --git a/ppdet/ext_op/test/test_corner_pool.py b/ppdet/ext_op/test/test_corner_pool.py new file mode 100755 index 000000000..ee5e6b07d --- /dev/null +++ b/ppdet/ext_op/test/test_corner_pool.py @@ -0,0 +1,120 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import cornerpool_lib + + +def bottom_pool_np(x): + height = x.shape[2] + output = x.copy() + for ind in range(height): + cur = output[:, :, ind:height, :] + next = output[:, :, :height - ind, :] + output[:, :, ind:height, :] = np.maximum(cur, next) + return output + + +def top_pool_np(x): + height = x.shape[2] + output = x.copy() + for ind in range(height): + cur = output[:, :, :height - ind, :] + next = output[:, :, ind:height, :] + output[:, :, :height - ind, :] = np.maximum(cur, next) + return output + + +def right_pool_np(x): + width = x.shape[3] + output = x.copy() + for ind in range(width): + cur = output[:, :, :, ind:width] + next = output[:, :, :, :width - ind] + output[:, :, :, ind:width] = np.maximum(cur, next) + return output + + +def left_pool_np(x): + width = x.shape[3] + output = x.copy() + for ind in range(width): + cur = output[:, :, :, :width - ind] + next = output[:, :, :, ind:width] + output[:, :, :, :width - ind] = np.maximum(cur, next) + return output + + +class TestRightPoolOp(unittest.TestCase): + def funcmap(self): + self.func_map = { + 'bottom_x': [cornerpool_lib.bottom_pool, bottom_pool_np], + 'top_x': [cornerpool_lib.top_pool, top_pool_np], + 'right_x': [cornerpool_lib.right_pool, right_pool_np], + 'left_x': [cornerpool_lib.left_pool, left_pool_np] + } + + def setup(self): + self.name = 'right_x' + + def test_check_output(self): + self.funcmap() + self.setup() + x_shape = (2, 10, 16, 16) + x_type = "float64" + + sp = fluid.Program() + tp = fluid.Program() + place = fluid.CUDAPlace(0) + + with fluid.program_guard(tp, sp): + x = fluid.layers.data( + name=self.name, + shape=x_shape, + dtype=x_type, + append_batch_size=False) + y = self.func_map[self.name][0](x) + + np.random.seed(0) + x_np = np.random.uniform(-1000, 1000, x_shape).astype(x_type) + + out_np = self.func_map[self.name][1](x_np) + + exe = fluid.Executor(place) + outs = exe.run(tp, feed={self.name: x_np}, fetch_list=[y]) + + self.assertTrue(np.allclose(outs, out_np)) + + +class TestTopPoolOp(TestRightPoolOp): + def setup(self): + self.name = 'top_x' + + +class TestBottomPoolOp(TestRightPoolOp): + def setup(self): + self.name = 'bottom_x' + + +class TestLeftPoolOp(TestRightPoolOp): + def setup(self): + self.name = 'left_x' + + +if __name__ == "__main__": + unittest.main() diff --git a/ppdet/modeling/anchor_heads/__init__.py b/ppdet/modeling/anchor_heads/__init__.py index c6e495598..49640bdf2 100644 --- a/ppdet/modeling/anchor_heads/__init__.py +++ b/ppdet/modeling/anchor_heads/__init__.py @@ -18,8 +18,10 @@ from . import rpn_head from . import yolo_head from . 
import retina_head from . import fcos_head +from . import corner_head from .rpn_head import * from .yolo_head import * from .retina_head import * from .fcos_head import * +from .corner_head import * diff --git a/ppdet/modeling/anchor_heads/corner_head.py b/ppdet/modeling/anchor_heads/corner_head.py new file mode 100644 index 000000000..fc7c64c19 --- /dev/null +++ b/ppdet/modeling/anchor_heads/corner_head.py @@ -0,0 +1,486 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Constant + +from ..backbones.hourglass import _conv_norm, kaiming_init +from ppdet.core.workspace import register +import numpy as np +try: + import cornerpool_lib +except: + print( + "warning: cornerpool_lib not found, compile in ext_op at first if needed" + ) + +__all__ = ['CornerHead'] + + +def corner_output(x, pool1, pool2, dim, name=None): + p_conv1 = fluid.layers.conv2d( + pool1 + pool2, + filter_size=3, + num_filters=dim, + padding=1, + param_attr=ParamAttr( + name=name + "_p_conv1_weight", + initializer=kaiming_init(pool1 + pool2, 3)), + bias_attr=False, + name=name + '_p_conv1') + p_bn1 = fluid.layers.batch_norm( + p_conv1, + param_attr=ParamAttr(name=name + '_p_bn1_weight'), + bias_attr=ParamAttr(name=name + '_p_bn1_bias'), + moving_mean_name=name + '_p_bn1_running_mean', + moving_variance_name=name + '_p_bn1_running_var', + name=name + '_p_bn1') + + conv1 = fluid.layers.conv2d( + x, + filter_size=1, + num_filters=dim, + param_attr=ParamAttr( + name=name + "_conv1_weight", initializer=kaiming_init(x, 1)), + bias_attr=False, + name=name + '_conv1') + bn1 = fluid.layers.batch_norm( + conv1, + param_attr=ParamAttr(name=name + '_bn1_weight'), + bias_attr=ParamAttr(name=name + '_bn1_bias'), + moving_mean_name=name + '_bn1_running_mean', + moving_variance_name=name + '_bn1_running_var', + name=name + '_bn1') + + relu1 = fluid.layers.relu(p_bn1 + bn1) + conv2 = _conv_norm( + relu1, 3, dim, pad=1, bn_act='relu', name=name + '_conv2') + return conv2 + + +def corner_pool(x, dim, pool1, pool2, is_test=False, name=None): + p1_conv1 = _conv_norm( + x, 3, 128, pad=1, bn_act='relu', name=name + '_p1_conv1') + pool1 = pool1(p1_conv1, is_test=is_test, name=name + '_pool1') + p2_conv1 = _conv_norm( + x, 3, 128, pad=1, bn_act='relu', name=name + '_p2_conv1') + pool2 = pool2(p2_conv1, is_test=is_test, name=name + '_pool2') + + conv2 = corner_output(x, pool1, pool2, dim, name) + return conv2 + + +def gather_feat(feat, ind, batch_size=1): + feats = [] + for bind in range(batch_size): + feat_b = feat[bind] + ind_b = ind[bind] + ind_b.stop_gradient = True + feat_bg = fluid.layers.gather(feat_b, ind_b) + feats.append(fluid.layers.unsqueeze(feat_bg, axes=[0])) + feat_g = fluid.layers.concat(feats, axis=0) + return feat_g + + +def mask_feat(feat, ind, 
batch_size=1): + feat_t = fluid.layers.transpose(feat, [0, 2, 3, 1]) + C = feat_t.shape[3] + feat_r = fluid.layers.reshape(feat_t, [0, -1, C]) + return gather_feat(feat_r, ind, batch_size) + + +def nms(heat): + hmax = fluid.layers.pool2d(heat, pool_size=3, pool_padding=1) + keep = fluid.layers.cast(heat == hmax, 'float32') + return heat * keep + + +def _topk(scores, batch_size, height, width, K): + scores_r = fluid.layers.reshape(scores, [batch_size, -1]) + topk_scores, topk_inds = fluid.layers.topk(scores_r, K) + topk_clses = topk_inds / (height * width) + topk_inds = topk_inds % (height * width) + topk_ys = fluid.layers.cast(topk_inds / width, 'float32') + topk_xs = fluid.layers.cast(topk_inds % width, 'float32') + return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs + + +def filter_scores(scores, index_list): + for ind in index_list: + tmp = scores * fluid.layers.cast((1 - ind), 'float32') + scores = tmp - fluid.layers.cast(ind, 'float32') + return scores + + +def decode(tl_heat, + br_heat, + tl_tag, + br_tag, + tl_regr, + br_regr, + ae_threshold=1, + num_dets=1000, + K=100, + batch_size=1): + shape = fluid.layers.shape(tl_heat) + H, W = shape[2], shape[3] + + tl_heat = fluid.layers.sigmoid(tl_heat) + br_heat = fluid.layers.sigmoid(br_heat) + + tl_heat_nms = nms(tl_heat) + br_heat_nms = nms(br_heat) + + tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat_nms, batch_size, + H, W, K) + br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat_nms, batch_size, + H, W, K) + tl_ys = fluid.layers.expand( + fluid.layers.reshape(tl_ys, [-1, K, 1]), [1, 1, K]) + tl_xs = fluid.layers.expand( + fluid.layers.reshape(tl_xs, [-1, K, 1]), [1, 1, K]) + br_ys = fluid.layers.expand( + fluid.layers.reshape(br_ys, [-1, 1, K]), [1, K, 1]) + br_xs = fluid.layers.expand( + fluid.layers.reshape(br_xs, [-1, 1, K]), [1, K, 1]) + + tl_regr = mask_feat(tl_regr, tl_inds, batch_size) + br_regr = mask_feat(br_regr, br_inds, batch_size) + tl_regr = fluid.layers.reshape(tl_regr, [-1, K, 1, 2]) + br_regr = fluid.layers.reshape(br_regr, [-1, 1, K, 2]) + + tl_xs = tl_xs + tl_regr[:, :, :, 0] + tl_ys = tl_ys + tl_regr[:, :, :, 1] + br_xs = br_xs + br_regr[:, :, :, 0] + br_ys = br_ys + br_regr[:, :, :, 1] + + bboxes = fluid.layers.stack([tl_xs, tl_ys, br_xs, br_ys], axis=-1) + + tl_tag = mask_feat(tl_tag, tl_inds, batch_size) + br_tag = mask_feat(br_tag, br_inds, batch_size) + tl_tag = fluid.layers.expand( + fluid.layers.reshape(tl_tag, [-1, K, 1]), [1, 1, K]) + br_tag = fluid.layers.expand( + fluid.layers.reshape(br_tag, [-1, 1, K]), [1, K, 1]) + dists = fluid.layers.abs(tl_tag - br_tag) + + tl_scores = fluid.layers.expand( + fluid.layers.reshape(tl_scores, [-1, K, 1]), [1, 1, K]) + br_scores = fluid.layers.expand( + fluid.layers.reshape(br_scores, [-1, 1, K]), [1, K, 1]) + scores = (tl_scores + br_scores) / 2. 
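+    # The expressions below reject implausible corner pairs: a (top-left, bottom-right)
+    # pair is dropped when the two corners predict different classes, when their
+    # associative-embedding distance exceeds ae_threshold, or when the geometry is
+    # invalid (bottom-right corner above or to the left of the top-left corner).
+    # filter_scores() sets the scores of rejected pairs to -1, below any valid
+    # sigmoid-averaged score, so the following top-k keeps only plausible boxes.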
+ + tl_clses = fluid.layers.expand( + fluid.layers.reshape(tl_clses, [-1, K, 1]), [1, 1, K]) + br_clses = fluid.layers.expand( + fluid.layers.reshape(br_clses, [-1, 1, K]), [1, K, 1]) + cls_inds = fluid.layers.cast(tl_clses != br_clses, 'int32') + dist_inds = fluid.layers.cast(dists > ae_threshold, 'int32') + + width_inds = fluid.layers.cast(br_xs < tl_xs, 'int32') + height_inds = fluid.layers.cast(br_ys < tl_ys, 'int32') + + scores = filter_scores(scores, + [cls_inds, dist_inds, width_inds, height_inds]) + scores = fluid.layers.reshape(scores, [-1, K * K]) + + scores, inds = fluid.layers.topk(scores, num_dets) + scores = fluid.layers.reshape(scores, [-1, num_dets, 1]) + + bboxes = fluid.layers.reshape(bboxes, [batch_size, -1, 4]) + bboxes = gather_feat(bboxes, inds, batch_size) + + clses = fluid.layers.reshape(tl_clses, [batch_size, -1, 1]) + clses = gather_feat(clses, inds, batch_size) + + tl_scores = fluid.layers.reshape(tl_scores, [batch_size, -1, 1]) + tl_scores = gather_feat(tl_scores, inds, batch_size) + br_scores = fluid.layers.reshape(br_scores, [batch_size, -1, 1]) + br_scores = gather_feat(br_scores, inds, batch_size) + + bboxes = fluid.layers.cast(bboxes, 'float32') + clses = fluid.layers.cast(clses, 'float32') + return bboxes, scores, tl_scores, br_scores, clses + + +@register +class CornerHead(object): + """ + CornerNet head with corner_pooling + + Args: + train_batch_size(int): batch_size in training process + test_batch_size(int): batch_size in test process, 1 by default + num_classes(int): num of classes, 80 by default + stack(int): stack of backbone, 2 by default + pull_weight(float): weight of pull_loss, 0.1 by default + push_weight(float): weight of push_loss, 0.1 by default + ae_threshold(float|int): threshold for valid distance of predicted tags, 1 by default + num_dets(int): num of detections, 1000 by default + top_k(int): choose top_k pair of corners in prediction, 100 by default + """ + __shared__ = ['num_classes', 'stack'] + + def __init__(self, + train_batch_size, + test_batch_size=1, + num_classes=80, + stack=2, + pull_weight=0.1, + push_weight=0.1, + ae_threshold=1, + num_dets=1000, + top_k=100): + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size + self.num_classes = num_classes + self.stack = stack + self.pull_weight = pull_weight + self.push_weight = push_weight + self.ae_threshold = ae_threshold + self.num_dets = num_dets + self.K = top_k + self.tl_heats = [] + self.br_heats = [] + self.tl_tags = [] + self.br_tags = [] + self.tl_offs = [] + self.br_offs = [] + + def pred_mod(self, x, dim, name=None): + conv0 = _conv_norm( + x, 1, 256, with_bn=False, bn_act='relu', name=name + '_0') + conv1 = fluid.layers.conv2d( + input=conv0, + filter_size=1, + num_filters=dim, + param_attr=ParamAttr( + name=name + "_1_weight", initializer=kaiming_init(conv0, 1)), + bias_attr=ParamAttr( + name=name + "_1_bias", initializer=Constant(-2.19)), + name=name + '_1') + return conv1 + + def get_output(self, input): + for ind in range(self.stack): + cnv = input[ind] + tl_modules = corner_pool( + cnv, + 256, + cornerpool_lib.top_pool, + cornerpool_lib.left_pool, + name='tl_modules_' + str(ind)) + br_modules = corner_pool( + cnv, + 256, + cornerpool_lib.bottom_pool, + cornerpool_lib.right_pool, + name='br_modules_' + str(ind)) + + tl_heat = self.pred_mod( + tl_modules, self.num_classes, name='tl_heats_' + str(ind)) + br_heat = self.pred_mod( + br_modules, self.num_classes, name='br_heats_' + str(ind)) + + tl_tag = self.pred_mod(tl_modules, 1, 
name='tl_tags_' + str(ind)) + br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind)) + + tl_off = self.pred_mod(tl_modules, 2, name='tl_offs_' + str(ind)) + br_off = self.pred_mod(br_modules, 2, name='br_offs_' + str(ind)) + + self.tl_heats.append(tl_heat) + self.br_heats.append(br_heat) + self.tl_tags.append(tl_tag) + self.br_tags.append(br_tag) + self.tl_offs.append(tl_off) + self.br_offs.append(br_off) + + def focal_loss(self, preds, gt, gt_masks): + preds_clip = [] + none_pos = fluid.layers.cast( + fluid.layers.reduce_sum(gt_masks) == 0, 'float32') + none_pos.stop_gradient = True + min = fluid.layers.assign(np.array([1e-4], dtype='float32')) + max = fluid.layers.assign(np.array([1 - 1e-4], dtype='float32')) + for pred in preds: + pred_s = fluid.layers.sigmoid(pred) + pred_min = fluid.layers.elementwise_max(pred_s, min) + pred_max = fluid.layers.elementwise_min(pred_min, max) + preds_clip.append(pred_max) + + ones = fluid.layers.ones_like(gt) + + fg_map = fluid.layers.cast(gt == ones, 'float32') + fg_map.stop_gradient = True + num_pos = fluid.layers.reduce_sum(fg_map) + min_num = fluid.layers.ones_like(num_pos) + num_pos = fluid.layers.elementwise_max(num_pos, min_num) + num_pos.stop_gradient = True + bg_map = fluid.layers.cast(gt < ones, 'float32') + bg_map.stop_gradient = True + neg_weights = fluid.layers.pow(1 - gt, 4) * bg_map + neg_weights.stop_gradient = True + loss = fluid.layers.assign(np.array([0], dtype='float32')) + for ind, pred in enumerate(preds_clip): + pos_loss = fluid.layers.log(pred) * fluid.layers.pow(1 - pred, + 2) * fg_map + + neg_loss = fluid.layers.log(1 - pred) * fluid.layers.pow( + pred, 2) * neg_weights + + pos_loss = fluid.layers.reduce_sum(pos_loss) + neg_loss = fluid.layers.reduce_sum(neg_loss) + focal_loss_ = (neg_loss + pos_loss) / (num_pos + none_pos) + loss -= focal_loss_ + return loss + + def ae_loss(self, tl_tag, br_tag, gt_masks): + num = fluid.layers.reduce_sum(gt_masks, dim=1) + num_stop_gradient = True + tag0 = fluid.layers.squeeze(tl_tag, [2]) + tag1 = fluid.layers.squeeze(br_tag, [2]) + tag_mean = (tag0 + tag1) / 2 + + tag0 = fluid.layers.pow(tag0 - tag_mean, 2) / (num + 1e-4) * gt_masks + tag1 = fluid.layers.pow(tag1 - tag_mean, 2) / (num + 1e-4) * gt_masks + tag0 = fluid.layers.reduce_sum(tag0) + tag1 = fluid.layers.reduce_sum(tag1) + + pull = tag0 + tag1 + + mask_1 = fluid.layers.expand( + fluid.layers.unsqueeze(gt_masks, [1]), [1, gt_masks.shape[1], 1]) + mask_2 = fluid.layers.expand( + fluid.layers.unsqueeze(gt_masks, [2]), [1, 1, gt_masks.shape[1]]) + mask = fluid.layers.cast((mask_1 + mask_2) == 2, 'float32') + mask.stop_gradient = True + + num2 = (num - 1) * num + num2.stop_gradient = True + tag_mean_1 = fluid.layers.expand( + fluid.layers.unsqueeze(tag_mean, [1]), [1, tag_mean.shape[1], 1]) + tag_mean_2 = fluid.layers.expand( + fluid.layers.unsqueeze(tag_mean, [2]), [1, 1, tag_mean.shape[1]]) + dist = tag_mean_1 - tag_mean_2 + dist = 1 - fluid.layers.abs(dist) + dist = fluid.layers.relu(dist) + dist = dist - 1 / (num + 1e-4) + dist = dist / (num2 + 1e-4) + dist = dist * mask + push = fluid.layers.reduce_sum(dist) + return pull, push + + def off_loss(self, off, gt_off, gt_masks): + mask = fluid.layers.unsqueeze(gt_masks, [2]) + mask = fluid.layers.expand_as(mask, gt_off) + mask.stop_gradient = True + off_loss = fluid.layers.smooth_l1(off, gt_off, mask, mask) + off_loss = fluid.layers.reduce_sum(off_loss) + total_num = fluid.layers.reduce_sum(gt_masks) + total_num.stop_gradient = True + return off_loss / (total_num + 1e-4) + + 
def get_loss(self, targets): + gt_tl_heat = targets['tl_heatmaps'] + gt_br_heat = targets['br_heatmaps'] + gt_masks = targets['tag_masks'] + gt_tl_off = targets['tl_regrs'] + gt_br_off = targets['br_regrs'] + gt_tl_ind = targets['tl_tags'] + gt_br_ind = targets['br_tags'] + gt_masks = fluid.layers.cast(gt_masks, 'float32') + + focal_loss = 0 + focal_loss_ = self.focal_loss(self.tl_heats, gt_tl_heat, gt_masks) + focal_loss += focal_loss_ + focal_loss_ = self.focal_loss(self.br_heats, gt_br_heat, gt_masks) + focal_loss += focal_loss_ + + pull_loss = 0 + push_loss = 0 + + ones = fluid.layers.assign(np.array([1], dtype='float32')) + tl_tags = [ + mask_feat(tl_tag, gt_tl_ind, self.train_batch_size) + for tl_tag in self.tl_tags + ] + br_tags = [ + mask_feat(br_tag, gt_br_ind, self.train_batch_size) + for br_tag in self.br_tags + ] + + pull_loss, push_loss = 0, 0 + + for tl_tag, br_tag in zip(tl_tags, br_tags): + pull, push = self.ae_loss(tl_tag, br_tag, gt_masks) + pull_loss += pull + push_loss += push + + tl_offs = [ + mask_feat(tl_off, gt_tl_ind, self.train_batch_size) + for tl_off in self.tl_offs + ] + br_offs = [ + mask_feat(br_off, gt_br_ind, self.train_batch_size) + for br_off in self.br_offs + ] + + off_loss = 0 + for tl_off, br_off in zip(tl_offs, br_offs): + off_loss += self.off_loss(tl_off, gt_tl_off, gt_masks) + off_loss += self.off_loss(br_off, gt_br_off, gt_masks) + + pull_loss = self.pull_weight * pull_loss + push_loss = self.push_weight * push_loss + + loss = ( + focal_loss + pull_loss + push_loss + off_loss) / len(self.tl_heats) + return {'loss': loss} + + def get_prediction(self, input): + ind = self.stack - 1 + tl_modules = corner_pool( + input, + 256, + cornerpool_lib.top_pool, + cornerpool_lib.left_pool, + is_test=True, + name='tl_modules_' + str(ind)) + br_modules = corner_pool( + input, + 256, + cornerpool_lib.bottom_pool, + cornerpool_lib.right_pool, + is_test=True, + name='br_modules_' + str(ind)) + + tl_heat = self.pred_mod( + tl_modules, self.num_classes, name='tl_heats_' + str(ind)) + br_heat = self.pred_mod( + br_modules, self.num_classes, name='br_heats_' + str(ind)) + + tl_tag = self.pred_mod(tl_modules, 1, name='tl_tags_' + str(ind)) + br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind)) + + tl_off = self.pred_mod(tl_modules, 2, name='tl_offs_' + str(ind)) + br_off = self.pred_mod(br_modules, 2, name='br_offs_' + str(ind)) + + return decode(tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off, + self.ae_threshold, self.num_dets, self.K, + self.test_batch_size) diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 652a38312..566d9a612 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -25,6 +25,7 @@ from . import retinanet from . import blazeface from . import faceboxes from . import fcos +from . 
import cornernet_squeeze from .faster_rcnn import * from .mask_rcnn import * @@ -37,3 +38,4 @@ from .retinanet import * from .blazeface import * from .faceboxes import * from .fcos import * +from .cornernet_squeeze import * diff --git a/ppdet/modeling/architectures/cascade_mask_rcnn.py b/ppdet/modeling/architectures/cascade_mask_rcnn.py index 30180ddac..e0bba61a4 100644 --- a/ppdet/modeling/architectures/cascade_mask_rcnn.py +++ b/ppdet/modeling/architectures/cascade_mask_rcnn.py @@ -408,7 +408,7 @@ class CascadeMaskRCNN(object): box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox'] for key in box_fields: inputs_def[key] = { - 'shape': [6], + 'shape': [None, 6], 'dtype': 'float32', 'lod_level': 1 } diff --git a/ppdet/modeling/architectures/cornernet_squeeze.py b/ppdet/modeling/architectures/cornernet_squeeze.py new file mode 100644 index 000000000..bcba61378 --- /dev/null +++ b/ppdet/modeling/architectures/cornernet_squeeze.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid + +from ppdet.core.workspace import register +import numpy as np + +__all__ = ['CornerNetSqueeze'] + + +def rescale_bboxes(bboxes, ratios, borders): + x1, y1, x2, y2 = fluid.layers.split(bboxes, 4) + x1 = x1 / ratios[:, 1] - borders[:, 2] + x2 = x2 / ratios[:, 1] - borders[:, 2] + y1 = y1 / ratios[:, 0] - borders[:, 0] + y2 = y2 / ratios[:, 0] - borders[:, 0] + return fluid.layers.concat([x1, y1, x2, y2], axis=2) + + +@register +class CornerNetSqueeze(object): + """ + """ + __category__ = 'architecture' + __inject__ = ['backbone', 'corner_head', 'fpn'] + __shared__ = ['num_classes'] + + def __init__(self, + backbone, + corner_head='CornerHead', + num_classes=80, + fpn=None): + super(CornerNetSqueeze, self).__init__() + self.backbone = backbone + self.corner_head = corner_head + self.num_classes = num_classes + self.fpn = fpn + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + body_feats = self.backbone(im) + if self.fpn is not None: + body_feats, _ = self.fpn.get_output(body_feats) + body_feats = [body_feats.values()[-1]] + if mode == 'train': + target_vars = [ + 'tl_heatmaps', 'br_heatmaps', 'tag_masks', 'tl_regrs', + 'br_regrs', 'tl_tags', 'br_tags' + ] + target = {key: feed_vars[key] for key in target_vars} + self.corner_head.get_output(body_feats) + loss = self.corner_head.get_loss(target) + return loss + + elif mode == 'test': + ratios = feed_vars['ratios'] + borders = feed_vars['borders'] + bboxes, scores, tl_scores, br_scores, clses = self.corner_head.get_prediction( + body_feats[-1]) + bboxes = rescale_bboxes(bboxes, ratios, borders) + detections = fluid.layers.concat([clses, scores, bboxes], axis=2) + + detections = detections[0] + return {'bbox': detections} + + def _inputs_def(self, image_shape, output_size, max_tag_len): + 
im_shape = [None] + image_shape + C = self.num_classes + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, + 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, + 'ratios': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 0}, + 'borders': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'tl_heatmaps': {'shape': [None, C, output_size, output_size], 'dtype': 'float32', 'lod_level': 0}, + 'br_heatmaps': {'shape': [None, C, output_size, output_size], 'dtype': 'float32', 'lod_level': 0}, + 'tl_regrs': {'shape': [None, max_tag_len, 2], 'dtype': 'float32', 'lod_level': 0}, + 'br_regrs': {'shape': [None, max_tag_len, 2], 'dtype': 'float32', 'lod_level': 0}, + 'tl_tags': {'shape': [None, max_tag_len], 'dtype': 'int64', 'lod_level': 0}, + 'br_tags': {'shape': [None, max_tag_len], 'dtype': 'int64', 'lod_level': 0}, + 'tag_masks': {'shape': [None, max_tag_len], 'dtype': 'int32', 'lod_level': 0}, + } + # yapf: enable + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=[ + 'image', 'im_id', 'gt_box', 'gt_class', 'tl_heatmaps', + 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', + 'tag_masks' + ], # for train + output_size=64, + max_tag_len=128, + use_dataloader=True, + iterable=False): + inputs_def = self._inputs_def(image_shape, output_size, max_tag_len) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=64, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, mode='train') + + def eval(self, feed_vars): + return self.build(feed_vars, mode='test') + + def test(self, feed_vars): + return self.build(feed_vars, mode='test') diff --git a/ppdet/modeling/architectures/mask_rcnn.py b/ppdet/modeling/architectures/mask_rcnn.py index 1f3f0104a..f7eb28fda 100644 --- a/ppdet/modeling/architectures/mask_rcnn.py +++ b/ppdet/modeling/architectures/mask_rcnn.py @@ -311,7 +311,7 @@ class MaskRCNN(object): box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox'] for key in box_fields: inputs_def[key] = { - 'shape': [6], + 'shape': [None, 6], 'dtype': 'float32', 'lod_level': 1 } diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index 2c31e792d..35fc91dcc 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -29,6 +29,7 @@ from . import res2net from . import hrnet from . import hrfpn from . import bfp +from . import hourglass from .resnet import * from .resnext import * @@ -45,3 +46,4 @@ from .res2net import * from .hrnet import * from .hrfpn import * from .bfp import * +from .hourglass import * diff --git a/ppdet/modeling/backbones/hourglass.py b/ppdet/modeling/backbones/hourglass.py new file mode 100644 index 000000000..b38f79bb4 --- /dev/null +++ b/ppdet/modeling/backbones/hourglass.py @@ -0,0 +1,275 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Uniform + +import functools +from ppdet.core.workspace import register +from .resnet import ResNet +import math + +__all__ = ['Hourglass'] + + +def kaiming_init(input, filter_size): + fan_in = input.shape[1] + std = (1.0 / (fan_in * filter_size * filter_size))**0.5 + return Uniform(0. - std, std) + + +def _conv_norm(x, + k, + out_dim, + stride=1, + pad=0, + groups=None, + with_bn=True, + bn_act=None, + ind=None, + name=None): + conv_name = "_conv" if ind is None else "_conv" + str(ind) + bn_name = "_bn" if ind is None else "_bn" + str(ind) + + conv = fluid.layers.conv2d( + input=x, + filter_size=k, + num_filters=out_dim, + stride=stride, + padding=pad, + groups=groups, + param_attr=ParamAttr( + name=name + conv_name + "_weight", initializer=kaiming_init(x, k)), + bias_attr=ParamAttr( + name=name + conv_name + "_bias", initializer=kaiming_init(x, k)) + if not with_bn else False, + name=name + '_output') + if with_bn: + pattr = ParamAttr(name=name + bn_name + '_weight') + battr = ParamAttr(name=name + bn_name + '_bias') + out = fluid.layers.batch_norm( + input=conv, + act=bn_act, + name=name + '_bn_output', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=name + bn_name + '_running_mean', + moving_variance_name=name + bn_name + + '_running_var') if with_bn else conv + else: + out = fluid.layers.relu(conv) + return out + + +def residual_block(x, out_dim, k=3, stride=1, name=None): + p = (k - 1) // 2 + conv1 = _conv_norm( + x, k, out_dim, pad=p, stride=stride, bn_act='relu', ind=1, name=name) + conv2 = _conv_norm(conv1, k, out_dim, pad=p, ind=2, name=name) + + skip = _conv_norm( + x, 1, out_dim, stride=stride, + name=name + '_skip') if stride != 1 or x.shape[1] != out_dim else x + return fluid.layers.elementwise_add( + x=skip, y=conv2, act='relu', name=name + "_add") + + +def fire_block(x, out_dim, sr=2, stride=1, name=None): + conv1 = _conv_norm(x, 1, out_dim // sr, ind=1, name=name) + conv_1x1 = fluid.layers.conv2d( + conv1, + filter_size=1, + num_filters=out_dim // 2, + stride=stride, + param_attr=ParamAttr( + name=name + "_conv_1x1_weight", initializer=kaiming_init(conv1, 1)), + bias_attr=False, + name=name + '_conv_1x1') + conv_3x3 = fluid.layers.conv2d( + conv1, + filter_size=3, + num_filters=out_dim // 2, + stride=stride, + padding=1, + groups=out_dim // sr, + param_attr=ParamAttr( + name=name + "_conv_3x3_weight", initializer=kaiming_init(conv1, 3)), + bias_attr=False, + name=name + '_conv_3x3', + use_cudnn=False) + conv2 = fluid.layers.concat( + [conv_1x1, conv_3x3], axis=1, name=name + '_conv2') + pattr = ParamAttr(name=name + '_bn2_weight') + battr = ParamAttr(name=name + '_bn2_bias') + + bn2 = fluid.layers.batch_norm( + input=conv2, + name=name + '_bn2', + param_attr=pattr, + bias_attr=battr, + moving_mean_name=name + '_bn2_running_mean', + moving_variance_name=name + '_bn2_running_var') + + if stride == 1 and x.shape[1] == out_dim: + return 
fluid.layers.elementwise_add( + x=bn2, y=x, act='relu', name=name + "_add_relu") + else: + return fluid.layers.relu(bn2, name="_relu") + + +def make_layer(x, in_dim, out_dim, modules, block, name=None): + layers = block(x, out_dim, name=name + '_0') + for i in range(1, modules): + layers = block(layers, out_dim, name=name + '_' + str(i)) + return layers + + +def make_hg_layer(x, in_dim, out_dim, modules, block, name=None): + layers = block(x, out_dim, stride=2, name=name + '_0') + for i in range(1, modules): + layers = block(layers, out_dim, name=name + '_' + str(i)) + return layers + + +def make_layer_revr(x, in_dim, out_dim, modules, block, name=None): + for i in range(modules - 1): + x = block(x, in_dim, name=name + '_' + str(i)) + layers = block(x, out_dim, name=name + '_' + str(modules - 1)) + return layers + + +def make_unpool_layer(x, dim, name=None): + pattr = ParamAttr(name=name + '_weight', initializer=kaiming_init(x, 4)) + battr = ParamAttr(name=name + '_bias', initializer=kaiming_init(x, 4)) + layer = fluid.layers.conv2d_transpose( + input=x, + num_filters=dim, + filter_size=4, + stride=2, + padding=1, + param_attr=pattr, + bias_attr=battr) + return layer + + +@register +class Hourglass(object): + """ + Hourglass Network, see https://arxiv.org/abs/1603.06937 + Args: + stack (int): stack of hourglass, 2 by default + dims (list): dims of each level in hg_module + modules (list): num of modules in each level + """ + __shared__ = ['stack'] + + def __init__(self, + stack=2, + dims=[256, 256, 384, 384, 512], + modules=[2, 2, 2, 2, 4], + block_name='fire'): + super(Hourglass, self).__init__() + self.stack = stack + assert len(dims) == len(modules), \ + "Expected len of dims equal to len of modules, Receiced len of "\ + "dims: {}, len of modules: {}".format(len(dims), len(modules)) + self.dims = dims + self.modules = modules + self.num_level = len(dims) - 1 + block_dict = {'fire': fire_block} + self.block = block_dict[block_name] + + def __call__(self, input, name='hg'): + inter = self.pre(input, name + '_pre') + cnvs = [] + for ind in range(self.stack): + hg = self.hg_module( + inter, + self.num_level, + self.dims, + self.modules, + name=name + '_hgs_' + str(ind)) + cnv = _conv_norm( + hg, + 3, + 256, + bn_act='relu', + pad=1, + name=name + '_cnvs_' + str(ind)) + cnvs.append(cnv) + + if ind < self.stack - 1: + inter = _conv_norm( + inter, 1, 256, name=name + '_inters__' + + str(ind)) + _conv_norm( + cnv, 1, 256, name=name + '_cnvs__' + str(ind)) + inter = fluid.layers.relu(inter) + inter = residual_block( + inter, 256, name=name + '_inters_' + str(ind)) + return cnvs + + def pre(self, x, name=None): + conv = _conv_norm( + x, 7, 128, stride=2, pad=3, bn_act='relu', name=name + '_0') + res1 = residual_block(conv, 256, stride=2, name=name + '_1') + res2 = residual_block(res1, 256, stride=2, name=name + '_2') + return res2 + + def hg_module(self, + x, + n=4, + dims=[256, 256, 384, 384, 512], + modules=[2, 2, 2, 2, 4], + make_up_layer=make_layer, + make_hg_layer=make_hg_layer, + make_low_layer=make_layer, + make_hg_layer_revr=make_layer_revr, + make_unpool_layer=make_unpool_layer, + name=None): + curr_mod = modules[0] + next_mod = modules[1] + curr_dim = dims[0] + next_dim = dims[1] + up1 = make_up_layer( + x, curr_dim, curr_dim, curr_mod, self.block, name=name + '_up1') + max1 = x + low1 = make_hg_layer( + max1, curr_dim, next_dim, curr_mod, self.block, name=name + '_low1') + low2 = self.hg_module( + low1, + n - 1, + dims[1:], + modules[1:], + make_up_layer=make_up_layer, + 
make_hg_layer=make_hg_layer, + make_low_layer=make_low_layer, + make_hg_layer_revr=make_hg_layer_revr, + make_unpool_layer=make_unpool_layer, + name=name + '_low2') if n > 1 else make_low_layer( + low1, + next_dim, + next_dim, + next_mod, + self.block, + name=name + '_low2') + low3 = make_hg_layer_revr( + low2, next_dim, curr_dim, curr_mod, self.block, name=name + '_low3') + up2 = make_unpool_layer(low3, curr_dim, name=name + '_up2') + merg = fluid.layers.elementwise_add(x=up1, y=up2, name=name + '_merg') + return merg diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 28abd477c..ca861cf18 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -378,7 +378,7 @@ class MultiClassSoftNMS(object): fluid.default_main_program(), name='softnms_pred_result', dtype='float32', - shape=[6], + shape=[-1, 6], lod_level=1) fluid.layers.py_func( func=_soft_nms, x=[bboxes, scores], out=pred_result) diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index e7a6eb8ab..1b91bfa34 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -148,10 +148,12 @@ class OptimizerBuilder(): self.optimizer = optimizer def __call__(self, learning_rate): - reg_type = self.regularizer['type'] + 'Decay' - reg_factor = self.regularizer['factor'] - regularization = getattr(regularizer, reg_type)(reg_factor) - + if self.regularizer: + reg_type = self.regularizer['type'] + 'Decay' + reg_factor = self.regularizer['factor'] + regularization = getattr(regularizer, reg_type)(reg_factor) + else: + regularization = None optim_args = self.optimizer.copy() optim_type = optim_args['type'] del optim_args['type'] diff --git a/ppdet/utils/coco_eval.py b/ppdet/utils/coco_eval.py index d2ceeacb4..8ccae76ea 100644 --- a/ppdet/utils/coco_eval.py +++ b/ppdet/utils/coco_eval.py @@ -230,9 +230,10 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False): xywh_res = [] for t in results: bboxes = t['bbox'][0] + if len(t['bbox'][1]) == 0: continue lengths = t['bbox'][1][0] im_ids = np.array(t['im_id'][0]).flatten() - if bboxes.shape == (1, 1) or bboxes is None: + if bboxes.shape == (1, 1) or bboxes is None or len(bboxes) == 0: continue k = 0 diff --git a/ppdet/utils/eval_utils.py b/ppdet/utils/eval_utils.py index 717053ad3..ef1d537b3 100644 --- a/ppdet/utils/eval_utils.py +++ b/ppdet/utils/eval_utils.py @@ -135,7 +135,8 @@ def eval_run(exe, mask_multi_scale_test = multi_scale_test and 'Mask' in cfg.architecture if multi_scale_test: - post_res = mstest_box_post_process(res, cfg) + post_res = mstest_box_post_process(res, multi_scale_test, + cfg.num_classes) res.update(post_res) if mask_multi_scale_test: place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() @@ -156,10 +157,16 @@ def eval_run(exe, if 'mask' in res: from ppdet.utils.post_process import mask_encode res['mask'] = mask_encode(res, resolution) + post_config = getattr(cfg, 'PostProcess', None) + if 'Corner' in cfg.architecture and post_config is not None: + from ppdet.utils.post_process import corner_post_process + corner_post_process(res, post_config, cfg.num_classes) results.append(res) if iter_id % 100 == 0: logger.info('Test iter {}'.format(iter_id)) iter_id += 1 + if len(res['bbox'][1]) == 0: + has_bbox = False images_num += len(res['bbox'][1][0]) if has_bbox else 1 except (StopIteration, fluid.core.EOFException): loader.reset() diff --git a/ppdet/utils/post_process.py b/ppdet/utils/post_process.py index c42fcc5a7..cf2519983 100644 --- a/ppdet/utils/post_process.py +++ b/ppdet/utils/post_process.py @@ -38,7 +38,7 @@ def box_flip(boxes, 
im_shape): def nms(dets, thresh): """Apply classic DPM-style greedy NMS.""" if dets.shape[0] == 0: - return [] + return dets[[], :] scores = dets[:, 0] x1 = dets[:, 1] y1 = dets[:, 2] @@ -86,8 +86,40 @@ def nms(dets, thresh): ovr = inter / (iarea + areas[j] - inter) if ovr >= thresh: suppressed[j] = 1 - - return np.where(suppressed == 0)[0] + keep = np.where(suppressed == 0)[0] + dets = dets[keep, :] + return dets + + +def soft_nms(dets, sigma, thres): + dets_final = [] + while len(dets) > 0: + maxpos = np.argmax(dets[:, 0]) + dets_final.append(dets[maxpos].copy()) + ts, tx1, ty1, tx2, ty2 = dets[maxpos] + scores = dets[:, 0] + # force remove bbox at maxpos + scores[maxpos] = -1 + x1 = dets[:, 1] + y1 = dets[:, 2] + x2 = dets[:, 3] + y2 = dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + xx1 = np.maximum(tx1, x1) + yy1 = np.maximum(ty1, y1) + xx2 = np.minimum(tx2, x2) + yy2 = np.minimum(ty2, y2) + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas + areas[maxpos] - inter) + weight = np.exp(-(ovr * ovr) / sigma) + scores = scores * weight + idx_keep = np.where(scores >= thres) + dets[:, 0] = scores + dets = dets[idx_keep] + dets_final = np.array(dets_final).reshape(-1, 5) + return dets_final def bbox_area(box): @@ -128,39 +160,49 @@ def box_voting(nms_dets, dets, vote_thresh): return top_dets -def get_nms_result(boxes, scores, cfg): - cls_boxes = [[] for _ in range(cfg.num_classes)] - for j in range(1, cfg.num_classes): - inds = np.where(scores[:, j] > cfg.MultiScaleTEST['score_thresh'])[0] - scores_j = scores[inds, j] - boxes_j = boxes[inds, j * 4:(j + 1) * 4] +def get_nms_result(boxes, + scores, + config, + num_classes, + background_label=0, + labels=None): + has_labels = labels is not None + cls_boxes = [[] for _ in range(num_classes)] + start_idx = 1 if background_label == 0 else 0 + for j in range(start_idx, num_classes): + inds = np.where(labels == j)[0] if has_labels else np.where( + scores[:, j] > config['score_thresh'])[0] + scores_j = scores[inds] if has_labels else scores[inds, j] + boxes_j = boxes[inds, :] if has_labels else boxes[inds, j * 4:(j + 1) * + 4] dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype( np.float32, copy=False) - keep = nms(dets_j, cfg.MultiScaleTEST['nms_thresh']) - nms_dets = dets_j[keep, :] - if cfg.MultiScaleTEST['enable_voting']: - nms_dets = box_voting(nms_dets, dets_j, - cfg.MultiScaleTEST['vote_thresh']) + if config.get('use_soft_nms', False): + nms_dets = soft_nms(dets_j, config['sigma'], config['nms_thresh']) + else: + nms_dets = nms(dets_j, config['nms_thresh']) + if config.get('enable_voting', False): + nms_dets = box_voting(nms_dets, dets_j, config['vote_thresh']) #add labels - label = np.array([j for _ in range(len(keep))]) + label = np.array([j for _ in range(len(nms_dets))]) nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype( np.float32, copy=False) cls_boxes[j] = nms_dets # Limit to max_per_image detections **over all classes** image_scores = np.hstack( - [cls_boxes[j][:, 1] for j in range(1, cfg.num_classes)]) - if len(image_scores) > cfg.MultiScaleTEST['detections_per_im']: - image_thresh = np.sort(image_scores)[-cfg.MultiScaleTEST[ - 'detections_per_im']] - for j in range(1, cfg.num_classes): + [cls_boxes[j][:, 1] for j in range(start_idx, num_classes)]) + if len(image_scores) > config['detections_per_im']: + image_thresh = np.sort(image_scores)[-config['detections_per_im']] + for j in range(start_idx, num_classes): keep = np.where(cls_boxes[j][:, 1] >= 
image_thresh)[0] cls_boxes[j] = cls_boxes[j][keep, :] - im_results = np.vstack([cls_boxes[j] for j in range(1, cfg.num_classes)]) + im_results = np.vstack( + [cls_boxes[j] for j in range(start_idx, num_classes)]) return im_results -def mstest_box_post_process(result, cfg): +def mstest_box_post_process(result, config, num_classes): """ Multi-scale Test Only available for batch_size=1 now. @@ -173,7 +215,7 @@ def mstest_box_post_process(result, cfg): for k in result.keys(): if 'bbox' in k: boxes = result[k][0] - boxes = np.reshape(boxes, (-1, 4 * cfg.num_classes)) + boxes = np.reshape(boxes, (-1, 4 * num_classes)) scores = result['score' + k[4:]][0] if 'flip' in k: boxes = box_flip(boxes, im_shape) @@ -183,7 +225,7 @@ def mstest_box_post_process(result, cfg): ms_boxes = np.concatenate(ms_boxes) ms_scores = np.concatenate(ms_scores) - bbox_pred = get_nms_result(ms_boxes, ms_scores, cfg) + bbox_pred = get_nms_result(ms_boxes, ms_scores, config, num_classes) post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])}) if use_flip: bbox = bbox_pred[:, 2:] @@ -271,3 +313,15 @@ def mask_encode(results, resolution, thresh_binarize=0.5): im_mask[:, :, np.newaxis], order='F'))[0] segms.append(segm) return segms + + +def corner_post_process(results, config, num_classes): + detections = results['bbox'][0] + keep_inds = (detections[:, 1] > -1) + detections = detections[keep_inds] + labels = detections[:, 0] + scores = detections[:, 1] + boxes = detections[:, 2:6] + cls_boxes = get_nms_result( + boxes, scores, config, num_classes, background_label=-1, labels=labels) + results.update({'bbox': (cls_boxes, [[len(cls_boxes)]])}) -- GitLab
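
A minimal sketch of how the Gaussian soft-NMS added in ppdet/utils/post_process.py behaves on toy inputs. It only assumes the soft_nms(dets, sigma, thres) signature from this patch, where each row of dets is [score, x1, y1, x2, y2], and that the ppdet package is importable; the sigma and thres values below are illustrative, not the config defaults.

    import numpy as np

    from ppdet.utils.post_process import soft_nms

    # Two heavily overlapping boxes plus one distant box;
    # each row is [score, x1, y1, x2, y2] as expected by soft_nms.
    dets = np.array([
        [0.9, 10, 10, 50, 50],
        [0.8, 12, 12, 52, 52],
        [0.7, 100, 100, 140, 140],
    ], dtype=np.float32)

    # Illustrative hyper-parameters: sigma controls the Gaussian decay,
    # thres is the score below which a decayed box is dropped.
    kept = soft_nms(dets, sigma=0.5, thres=0.05)

    # Unlike hard NMS, the overlapping second box is kept with a decayed
    # score (around 0.2 instead of 0.8) rather than being suppressed outright.
    print(kept)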