diff --git a/configs/anchor_free/README.md b/configs/anchor_free/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dcc6de223251e205ebd228052472d8df9059868f
--- /dev/null
+++ b/configs/anchor_free/README.md
@@ -0,0 +1,65 @@
+# Anchor Free Models
+
+## Contents
+- [Introduction](#introduction)
+- [Model Zoo and Baselines](#model-zoo-and-baselines)
+- [Algorithm Details](#algorithm-details)
+- [How to Contribute](#how-to-contribute)
+
+## Introduction
+Mainstream detection algorithms currently fall into two broad categories: single-stage and two-stage. Classic single-stage algorithms include SSD and YOLO, while the two-stage family is represented by the R-CNN series; both are covered in the [PaddleDetection Model Zoo](../MODEL_ZOO.md). What they have in common is that they first define a dense set of prior anchor boxes of varying sizes and then classify and regress on top of these priors, so their performance is heavily constrained by the anchor design itself. Since CornerNet was proposed, a number of anchor-free methods have emerged, and PaddleDetection also integrates a series of anchor-free algorithms.
+
+## Model Zoo and Baselines
+The table below lists the network architectures currently supported by PaddleDetection; see [Algorithm Details](#algorithm-details) for specifics.
+
+| | ResNet50 | ResNet50-vd | Hourglass104 |
+|:------------------------:|:--------:|:--------------------------:|:------------------------:|
+| [CornerNet-Squeeze](#CornerNet-Squeeze) | x | ✓ | ✓ |
+| [FCOS](#FCOS) | ✓ | x | x |
+
+
+### Model Zoo
+
+#### mAP on the COCO dataset
+
+| Architecture | Backbone | Images/GPU | Pretrained model | mAP | FPS | Download |
+|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|
+| CornerNet-Squeeze | Hourglass104 | 14 | None | 34.5 | 35.5 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_hg104.tar) |
+| CornerNet-Squeeze | ResNet50-vd | 14 | [faster\_rcnn\_r50\_vd\_fpn\_2x](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar) | 32.7 | 42.45 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_r50_vd_fpn.tar) |
+| CornerNet-Squeeze-dcn | ResNet50-vd | 14 | [faster\_rcnn\_dcn\_r50\_vd\_fpn\_2x](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar) | 34.9 | 40.05 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_dcn_r50_vd_fpn.tar) |
+| CornerNet-Squeeze-dcn-mixup-cosine* | ResNet50-vd | 14 | [faster\_rcnn\_dcn\_r50\_vd\_fpn\_2x](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar) | 38.2 | 40.05 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.pdparams) |
+| FCOS | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 39.8 | - | [Download](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_1x.pdparams) |
+| FCOS+multiscale_train | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 42.0 | - | [Download](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_multiscale_2x.pdparams) |
+
+**Notes:**
+
+- Model FPS is measured with tools/eval.py on a single Tesla V100 GPU.
+- When CornerNet-Squeeze uses a ResNet backbone, an FPN is added and the P3 level of the FPN is used as the backbone output feature map.
+- \*CornerNet-Squeeze-dcn-mixup-cosine is the best-performing variant built on the original CornerNet-Squeeze: on top of the ResNet backbone it adds mixup preprocessing and cosine_decay learning-rate scheduling.
+- FCOS uses GIoU loss (a sketch is given after these notes), predicts centerness with the location branch, normalizes the offsets to the top-left and bottom-right corners, and matches ground truth by object center.
+
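+Below is a minimal NumPy sketch of the GIoU metric referred to above (illustrative only, not the repository implementation; boxes are assumed to be in `[x1, y1, x2, y2]` format):
+
+```python
+import numpy as np
+
+def giou(a, b):
+    # intersection
+    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
+    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
+    inter = max(0., ix2 - ix1) * max(0., iy2 - iy1)
+    # union
+    area_a = (a[2] - a[0]) * (a[3] - a[1])
+    area_b = (b[2] - b[0]) * (b[3] - b[1])
+    union = area_a + area_b - inter
+    iou = inter / union
+    # smallest enclosing box
+    cx1, cy1 = min(a[0], b[0]), min(a[1], b[1])
+    cx2, cy2 = max(a[2], b[2]), max(a[3], b[3])
+    c_area = (cx2 - cx1) * (cy2 - cy1)
+    return iou - (c_area - union) / c_area
+
+# the regression loss is then 1 - giou(pred, gt)
+print(giou(np.array([0., 0., 10., 10.]), np.array([5., 5., 15., 15.])))  # ~ -0.08
+```
+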
+## Algorithm Details
+
+### CornerNet-Squeeze
+
+**Introduction:** [CornerNet-Squeeze](https://arxiv.org/abs/1904.08900) improves on [CornerNet](https://arxiv.org/abs/1808.01244). It predicts the top-left and bottom-right corners of each object box and, borrowing ideas from SqueezeNet and MobileNet, optimizes CornerNet's Hourglass-104 backbone to greatly speed up inference. Compared with the original [YOLO-v3](https://arxiv.org/abs/1804.02767), it offers advantages in both accuracy and inference speed.
+
+**Highlights:**
+
+- Uses corner pooling to locate the top-left and bottom-right corners of candidate boxes (a minimal sketch of the pooling semantics follows this list)
+- Replaces the residual blocks in Hourglass-104 with the fire modules from SqueezeNet
+- Replaces the second 3x3 convolution with a 3x3 depthwise separable convolution
+
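+The sketch below illustrates the semantics of the four corner-pooling directions with NumPy (illustrative only; in this repository they are implemented as custom CUDA ops under `ppdet/ext_op`):
+
+```python
+import numpy as np
+
+def bottom_pool(x):  # running max scanning top -> bottom, x shape [N, C, H, W]
+    return np.maximum.accumulate(x, axis=2)
+
+def top_pool(x):     # running max scanning bottom -> top
+    return np.maximum.accumulate(x[:, :, ::-1, :], axis=2)[:, :, ::-1, :]
+
+def right_pool(x):   # running max scanning left -> right
+    return np.maximum.accumulate(x, axis=3)
+
+def left_pool(x):    # running max scanning right -> left
+    return np.maximum.accumulate(x[:, :, :, ::-1], axis=3)[:, :, :, ::-1]
+
+feat = np.random.rand(1, 256, 64, 64).astype('float32')
+tl_feat = top_pool(feat) + left_pool(feat)      # top-left corner branch
+br_feat = bottom_pool(feat) + right_pool(feat)  # bottom-right corner branch
+```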
+
+### FCOS
+
+**Introduction:** [FCOS](https://arxiv.org/abs/1904.01355) is a dense-prediction anchor-free detector. It reuses the RetinaNet architecture, regresses the object's height and width directly on the feature map, and predicts the object category together with a centerness score (how far a feature-map pixel is from the object center); the centerness is finally used as a weight to rescale the object score.
+
+**Highlights:**
+
+- Uses the FPN structure to predict boxes of different scales at different levels, which avoids multiple overlapping boxes at the same feature-map pixel
+- A single center-ness branch predicts whether the current point is an object center, suppressing low-quality false positives (a minimal sketch of the centerness target follows this list)
+
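+The sketch below shows how the centerness target is defined for a single location, following the FCOS paper (illustrative only; the names are not taken from the repository code):
+
+```python
+import numpy as np
+
+def centerness_target(l, t, r, b):
+    # l, t, r, b: distances from the location to the left/top/right/bottom
+    # sides of the matched ground-truth box
+    return np.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
+
+# a location near the box center scores close to 1, a location near the
+# border scores close to 0; at inference the classification score is
+# multiplied by the predicted centerness to suppress low-quality boxes
+print(centerness_target(10., 12., 11., 9.))  # ~0.83
+print(centerness_target(1., 10., 30., 5.))   # ~0.13
+```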
+
+## How to Contribute
+Contributions to the anchor-free detection models in PaddleDetection are very welcome: feel free to submit a PR for us to review. Feedback is equally appreciated; please open an issue and we will respond promptly.
diff --git a/configs/anchor_free/cornernet_squeeze.yml b/configs/anchor_free/cornernet_squeeze.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a3e5ac0c507c4555092550c880774b7a49293993
--- /dev/null
+++ b/configs/anchor_free/cornernet_squeeze.yml
@@ -0,0 +1,147 @@
+architecture: CornerNetSqueeze
+use_gpu: true
+max_iters: 500000
+log_smooth_window: 20
+log_iter: 20
+save_dir: output
+snapshot_iter: 10000
+metric: COCO
+pretrain_weights: NULL
+weights: output/cornernet_squeeze/model_final
+num_classes: 80
+stack: 2
+
+CornerNetSqueeze:
+ backbone: Hourglass
+ corner_head: CornerHead
+
+Hourglass:
+ dims: [256, 256, 384, 384, 512]
+ modules: [2, 2, 2, 2, 4]
+
+CornerHead:
+ train_batch_size: 14
+ test_batch_size: 1
+ ae_threshold: 0.5
+ num_dets: 100
+ top_k: 20
+
+PostProcess:
+ use_soft_nms: true
+ detections_per_im: 100
+ nms_thresh: 0.001
+ sigma: 0.5
+
+LearningRate:
+ base_lr: 0.00025
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones:
+ - 450000
+
+OptimizerBuilder:
+ optimizer:
+ type: Adam
+ regularizer: NULL
+
+TrainReader:
+ inputs_def:
+ image_shape: [3, 511, 511]
+ fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: annotations/instances_train2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: False
+ - !CornerCrop
+ input_size: 511
+ - !Resize
+ target_dim: 511
+ - !RandomFlipImage
+ prob: 0.5
+ - !CornerRandColor
+ saturation: 0.4
+ contrast: 0.4
+ brightness: 0.4
+ - !Lighting
+ eigval: [0.2141788, 0.01817699, 0.00341571]
+ eigvec: [[-0.58752847, -0.69563484, 0.41340352],
+ [-0.5832747, 0.00994535, -0.81221408],
+ [-0.56089297, 0.71832671, 0.41158938]]
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: False
+ is_channel_first: False
+ - !Permute
+ to_bgr: False
+ - !CornerTarget
+ output_size: [64, 64]
+ num_classes: 80
+ batch_size: 14
+ shuffle: true
+ drop_last: true
+ worker_num: 2
+ use_process: true
+ drop_empty: false
+
+EvalReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: val2017
+ anno_path: annotations/instances_val2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ batch_size: 1
+ drop_empty: false
+ worker_num: 2
+ use_process: true
+
+TestReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !ImageFolder
+ anno_path: annotations/instances_val2017.json
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ batch_size: 1
diff --git a/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fbc65306868333cd7aabe778b7ee423ceddb357e
--- /dev/null
+++ b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn.yml
@@ -0,0 +1,165 @@
+architecture: CornerNetSqueeze
+use_gpu: true
+max_iters: 500000
+log_smooth_window: 20
+log_iter: 20
+save_dir: output
+snapshot_iter: 10000
+metric: COCO
+pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar
+weights: output/cornernet_squeeze_dcn_r50_vd_fpn/model_final
+num_classes: 80
+stack: 1
+
+CornerNetSqueeze:
+ backbone: ResNet
+ fpn: FPN
+ corner_head: CornerHead
+
+ResNet:
+ norm_type: bn
+ depth: 50
+ feature_maps: [3, 4, 5]
+ freeze_at: 2
+ variant: d
+ dcn_v2_stages: [3, 4, 5]
+
+FPN:
+ min_level: 3
+ max_level: 6
+ num_chan: 256
+ spatial_scale: [0.03125, 0.0625, 0.125]
+
+CornerHead:
+ train_batch_size: 14
+ test_batch_size: 1
+ ae_threshold: 0.5
+ num_dets: 100
+ top_k: 20
+
+PostProcess:
+ use_soft_nms: true
+ detections_per_im: 100
+ nms_thresh: 0.001
+ sigma: 0.5
+
+LearningRate:
+ base_lr: 0.0005
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones:
+ - 400000
+ - 450000
+ - !LinearWarmup
+ start_factor: 0.
+ steps: 4000
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
+
+TrainReader:
+ inputs_def:
+ image_shape: [3, 511, 511]
+ fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: annotations/instances_train2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: False
+ - !CornerCrop
+ input_size: 511
+ - !Resize
+ target_dim: 511
+ - !RandomFlipImage
+ prob: 0.5
+ - !CornerRandColor
+ saturation: 0.4
+ contrast: 0.4
+ brightness: 0.4
+ - !Lighting
+ eigval: [0.2141788, 0.01817699, 0.00341571]
+ eigvec: [[-0.58752847, -0.69563484, 0.41340352],
+ [-0.5832747, 0.00994535, -0.81221408],
+ [-0.56089297, 0.71832671, 0.41158938]]
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: False
+ is_channel_first: False
+ - !Permute
+ to_bgr: False
+ - !CornerTarget
+ output_size: [64, 64]
+ num_classes: 80
+ batch_size: 14
+ shuffle: true
+ drop_last: true
+ worker_num: 2
+ use_process: true
+ drop_empty: false
+
+EvalReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: val2017
+ anno_path: annotations/instances_val2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ use_process: true
+ batch_size: 1
+ drop_empty: false
+ worker_num: 2
+
+TestReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !ImageFolder
+ anno_path: annotations/instances_val2017.json
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ batch_size: 1
diff --git a/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e5ae6cb0d8dbba184edb9d4ce7b1b72b2b234410
--- /dev/null
+++ b/configs/anchor_free/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine.yml
@@ -0,0 +1,169 @@
+architecture: CornerNetSqueeze
+use_gpu: true
+max_iters: 500000
+log_smooth_window: 20
+log_iter: 20
+save_dir: output
+snapshot_iter: 10000
+metric: COCO
+pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar
+weights: output/cornernet_squeeze_dcn_r50_vd_fpn_mixup_cosine/model_final
+num_classes: 80
+stack: 1
+
+CornerNetSqueeze:
+ backbone: ResNet
+ fpn: FPN
+ corner_head: CornerHead
+
+ResNet:
+ norm_type: bn
+ depth: 50
+ feature_maps: [3, 4, 5]
+ freeze_at: 2
+ variant: d
+ dcn_v2_stages: [3, 4, 5]
+
+FPN:
+ min_level: 3
+ max_level: 6
+ num_chan: 256
+ spatial_scale: [0.03125, 0.0625, 0.125]
+
+CornerHead:
+ train_batch_size: 14
+ test_batch_size: 1
+ ae_threshold: 0.5
+ num_dets: 100
+ top_k: 20
+
+PostProcess:
+ use_soft_nms: true
+ detections_per_im: 100
+ nms_thresh: 0.001
+ sigma: 0.5
+
+LearningRate:
+ base_lr: 0.005
+ schedulers:
+ - !CosineDecay
+ max_iters: 500000
+ - !LinearWarmup
+ start_factor: 0.
+ steps: 4000
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
+
+TrainReader:
+ inputs_def:
+ image_shape: [3, 511, 511]
+ fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks']
+ output_size: 64
+ max_tag_len: 256
+ dataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: annotations/instances_train2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: False
+ with_mixup: True
+ - !MixupImage
+ alpha: 1.5
+ beta: 1.5
+ - !CornerCrop
+ input_size: 511
+ - !Resize
+ target_dim: 511
+ - !RandomFlipImage
+ prob: 0.5
+ - !CornerRandColor
+ saturation: 0.4
+ contrast: 0.4
+ brightness: 0.4
+ - !Lighting
+ eigval: [0.2141788, 0.01817699, 0.00341571]
+ eigvec: [[-0.58752847, -0.69563484, 0.41340352],
+ [-0.5832747, 0.00994535, -0.81221408],
+ [-0.56089297, 0.71832671, 0.41158938]]
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: False
+ is_channel_first: False
+ - !Permute
+ to_bgr: False
+ - !CornerTarget
+ output_size: [64, 64]
+ num_classes: 80
+ max_tag_len: 256
+ batch_size: 14
+ shuffle: true
+ drop_last: true
+ worker_num: 2
+ use_process: true
+ drop_empty: false
+ mixup_epoch: 200
+
+EvalReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: val2017
+ anno_path: annotations/instances_val2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ use_process: true
+ batch_size: 1
+ drop_empty: false
+ worker_num: 2
+
+TestReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !ImageFolder
+ anno_path: annotations/instances_val2017.json
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ batch_size: 1
diff --git a/configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml b/configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..918b49b41a6092bc9aa29df253b2ed6c80bc72f4
--- /dev/null
+++ b/configs/anchor_free/cornernet_squeeze_r50_vd_fpn.yml
@@ -0,0 +1,157 @@
+architecture: CornerNetSqueeze
+use_gpu: true
+max_iters: 500000
+log_smooth_window: 20
+log_iter: 20
+save_dir: output
+snapshot_iter: 10000
+metric: COCO
+pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_2x.tar
+weights: output/cornernet_squeeze_r50_vd_fpn/model_final
+num_classes: 80
+stack: 1
+
+CornerNetSqueeze:
+ backbone: ResNet
+ fpn: FPN
+ corner_head: CornerHead
+
+ResNet:
+ norm_type: affine_channel
+ depth: 50
+ feature_maps: [3, 4, 5]
+ freeze_at: 2
+ variant: d
+
+FPN:
+ min_level: 3
+ max_level: 6
+ num_chan: 256
+ spatial_scale: [0.03125, 0.0625, 0.125]
+
+CornerHead:
+ train_batch_size: 14
+ test_batch_size: 1
+ ae_threshold: 0.5
+ num_dets: 100
+ top_k: 20
+
+PostProcess:
+ use_soft_nms: true
+ detections_per_im: 100
+ nms_thresh: 0.001
+ sigma: 0.5
+
+LearningRate:
+ base_lr: 0.0005
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones:
+ - 450000
+
+OptimizerBuilder:
+ optimizer:
+ type: Adam
+ regularizer: NULL
+
+TrainReader:
+ inputs_def:
+ image_shape: [3, 511, 511]
+ fields: ['image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps', 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags', 'tag_masks']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: annotations/instances_train2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: False
+ - !CornerCrop
+ input_size: 511
+ - !Resize
+ target_dim: 511
+ - !RandomFlipImage
+ prob: 0.5
+ - !CornerRandColor
+ saturation: 0.4
+ contrast: 0.4
+ brightness: 0.4
+ - !Lighting
+ eigval: [0.2141788, 0.01817699, 0.00341571]
+ eigvec: [[-0.58752847, -0.69563484, 0.41340352],
+ [-0.5832747, 0.00994535, -0.81221408],
+ [-0.56089297, 0.71832671, 0.41158938]]
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: False
+ is_channel_first: False
+ - !Permute
+ to_bgr: False
+ - !CornerTarget
+ output_size: [64, 64]
+ num_classes: 80
+ batch_size: 14
+ shuffle: true
+ drop_last: true
+ worker_num: 2
+ use_process: true
+ drop_empty: false
+
+EvalReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !COCODataSet
+ image_dir: val2017
+ anno_path: annotations/instances_val2017.json
+ dataset_dir: dataset/coco
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ use_process: true
+ batch_size: 1
+ drop_empty: false
+ worker_num: 2
+
+TestReader:
+ inputs_def:
+ fields: ['image', 'im_id', 'ratios', 'borders']
+ output_size: 64
+ dataset:
+ !ImageFolder
+ anno_path: annotations/instances_val2017.json
+ with_background: false
+ sample_transforms:
+ - !DecodeImage
+ to_rgb: false
+ - !CornerCrop
+ is_train: false
+ - !CornerRatio
+ input_size: 511
+ output_size: 64
+ - !Permute
+ to_bgr: False
+ - !NormalizeImage
+ mean: [0.40789654, 0.44719302, 0.47026115]
+ std: [0.28863828, 0.27408164, 0.2780983]
+ is_scale: True
+ is_channel_first: True
+ batch_size: 1
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index c20b090ee453f671e743de7ce2bb9dd0695b3a7b..84c8843bea92dc56d7a177e8ae74da9a5badc210 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -214,3 +214,7 @@ Please refer [face detection models](https://github.com/PaddlePaddle/PaddleDetec
### Object Detection in Open Images Dataset V5
Please refer [Open Images Dataset V5 Baseline model](featured_model/OIDV5_BASELINE_MODEL.md) for details.
+
+### Anchor Free Models
+
+Please refer [Anchor Free Models](featured_model/ANCHOR_FREE_DETECTION.md) for details.
diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md
index df7ccf9e5551a27d5cf27103c90aabae0baaf491..19ade0e6dfc5aa98832a724fbae4c77a1be09d22 100644
--- a/docs/MODEL_ZOO_cn.md
+++ b/docs/MODEL_ZOO_cn.md
@@ -204,3 +204,7 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
### 基于Open Images V5数据集的物体检测
详细请参考[Open Images V5数据集基线模型](featured_model/OIDV5_BASELINE_MODEL.md)。
+
+### Anchor Free系列模型
+
+详细请参考[Anchor Free系列模型](featured_model/ANCHOR_FREE_DETECTION.md)。
diff --git a/docs/featured_model/ANCHOR_FREE_DETECTION.md b/docs/featured_model/ANCHOR_FREE_DETECTION.md
new file mode 100644
index 0000000000000000000000000000000000000000..2563896f2ced185838873b2e9ae520efd8f990f5
--- /dev/null
+++ b/docs/featured_model/ANCHOR_FREE_DETECTION.md
@@ -0,0 +1 @@
+**For the documentation and tutorial, please refer to:** [ANCHOR\_FREE\_DETECTION.md](../../configs/anchor_free/README.md)
diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py
index 994f70d707e842bdc7cd9aa903cf92e149419d6f..d41efd9341a8717577b63e56c67b8e64d69d3393 100644
--- a/ppdet/data/transform/op_helper.py
+++ b/ppdet/data/transform/op_helper.py
@@ -393,3 +393,52 @@ def is_poly(segm):
assert isinstance(segm, (list, dict)), \
"Invalid segm type: {}".format(type(segm))
return isinstance(segm, list)
+
+
+def gaussian_radius(bbox_size, min_overlap):
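+    """Gaussian radius used for the CornerNet heatmap target: a corner that
+    deviates from the ground-truth corner by no more than this radius still
+    produces a box with IoU >= min_overlap; the three quadratic cases below
+    bound different shift directions and the smallest radius is returned.
+    """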
+ height, width = bbox_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
+ radius1 = (b1 - sq1) / (2 * a1)
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
+ radius2 = (b2 - sq2) / (2 * a2)
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
+ radius3 = (b3 + sq3) / (2 * a3)
+ return min(radius1, radius2, radius3)
+
+
+def draw_gaussian(heatmap, center, radius, k=1, delta=6):
+    diameter = 2 * radius + 1
+    gaussian = gaussian2D((diameter, diameter), sigma=diameter / delta)
+
+ x, y = center
+
+ height, width = heatmap.shape[0:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
+ radius + right]
+ np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+
+
+def gaussian2D(shape, sigma=1):
+ m, n = [(ss - 1.) / 2. for ss in shape]
+ y, x = np.ogrid[-m:m + 1, -n:n + 1]
+
+ h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
+ return h
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index 8b3a479991889275e6185cc108c573ac9883de99..750e9b82e84fdccf14e9099f2469ca4e64069c61 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -42,7 +42,7 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling,
- is_poly)
+ is_poly, gaussian_radius, draw_gaussian)
logger = logging.getLogger(__name__)
@@ -1243,10 +1243,13 @@ class ColorDistort(BaseOperator):
def __call__(self, sample, context=None):
img = sample['image']
if self.random_apply:
- distortions = np.random.permutation([
- self.apply_brightness, self.apply_contrast,
- self.apply_saturation, self.apply_hue
- ])
+ functions = [
+ self.apply_brightness,
+ self.apply_contrast,
+ self.apply_saturation,
+ self.apply_hue,
+ ]
+ distortions = np.random.permutation(functions)
for func in distortions:
img = func(img)
sample['image'] = img
@@ -1266,6 +1269,66 @@ class ColorDistort(BaseOperator):
return sample
+@register_op
+class CornerRandColor(ColorDistort):
+ """Random color for CornerNet series models.
+ Args:
+ saturation (float): saturation settings.
+ contrast (float): contrast settings.
+ brightness (float): brightness settings.
+ is_scale (bool): whether to scale the input image.
+ """
+
+ def __init__(self,
+ saturation=0.4,
+ contrast=0.4,
+ brightness=0.4,
+ is_scale=True):
+ super(CornerRandColor, self).__init__(
+ saturation=saturation, contrast=contrast, brightness=brightness)
+ self.is_scale = is_scale
+
+ def apply_saturation(self, img, img_gray):
+ alpha = 1. + np.random.uniform(
+ low=-self.saturation, high=self.saturation)
+ self._blend(alpha, img, img_gray[:, :, None])
+ return img
+
+ def apply_contrast(self, img, img_gray):
+ alpha = 1. + np.random.uniform(low=-self.contrast, high=self.contrast)
+ img_mean = img_gray.mean()
+ self._blend(alpha, img, img_mean)
+ return img
+
+ def apply_brightness(self, img, img_gray):
+ alpha = 1 + np.random.uniform(
+ low=-self.brightness, high=self.brightness)
+ img *= alpha
+ return img
+
+ def _blend(self, alpha, img, img_mean):
+ img *= alpha
+ img_mean *= (1 - alpha)
+ img += img_mean
+
+ def __call__(self, sample, context=None):
+ img = sample['image']
+ if self.is_scale:
+ img = img.astype(np.float32, copy=False)
+ img /= 255.
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ functions = [
+ self.apply_brightness,
+ self.apply_contrast,
+ self.apply_saturation,
+ ]
+ distortions = np.random.permutation(functions)
+ for func in distortions:
+ img = func(img, img_gray)
+ sample['image'] = img
+ return sample
+
+
@register_op
class NormalizePermute(BaseOperator):
"""Normalize and permute channel order.
@@ -1672,3 +1735,239 @@ class BboxXYXY2XYWH(BaseOperator):
bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
sample['gt_bbox'] = bbox
return sample
+
+
+@register_op
+class Lighting(BaseOperator):
+ """
+    Lighting the image by eigenvalues and eigenvectors
+ Args:
+ eigval (list): eigenvalues
+ eigvec (list): eigenvectors
+ alphastd (float): random weight of lighting, 0.1 by default
+ """
+
+ def __init__(self, eigval, eigvec, alphastd=0.1):
+ super(Lighting, self).__init__()
+ self.alphastd = alphastd
+ self.eigval = np.array(eigval).astype('float32')
+ self.eigvec = np.array(eigvec).astype('float32')
+
+ def __call__(self, sample, context=None):
+ alpha = np.random.normal(scale=self.alphastd, size=(3, ))
+ sample['image'] += np.dot(self.eigvec, self.eigval * alpha)
+ return sample
+
+
+@register_op
+class CornerTarget(BaseOperator):
+ """
+ Generate targets for CornerNet by ground truth data.
+ Args:
+ output_size (int): the size of output heatmaps.
+ num_classes (int): num of classes.
+ gaussian_bump (bool): whether to apply gaussian bump on gt targets.
+ True by default.
+ gaussian_rad (int): radius of gaussian bump. If it is set to -1, the
+ radius will be calculated by iou. -1 by default.
+        gaussian_iou (float): the IoU threshold between a predicted bbox and
+            a gt bbox. If the IoU is larger than the threshold, the predicted
+            bbox is treated as a positive sample. 0.3 by default
+ max_tag_len (int): max num of gt box per image.
+ """
+
+ def __init__(self,
+ output_size,
+ num_classes,
+ gaussian_bump=True,
+ gaussian_rad=-1,
+ gaussian_iou=0.3,
+ max_tag_len=128):
+ super(CornerTarget, self).__init__()
+ self.num_classes = num_classes
+ self.output_size = output_size
+ self.gaussian_bump = gaussian_bump
+ self.gaussian_rad = gaussian_rad
+ self.gaussian_iou = gaussian_iou
+ self.max_tag_len = max_tag_len
+
+ def __call__(self, sample, context=None):
+ tl_heatmaps = np.zeros(
+ (self.num_classes, self.output_size[0], self.output_size[1]),
+ dtype=np.float32)
+ br_heatmaps = np.zeros(
+ (self.num_classes, self.output_size[0], self.output_size[1]),
+ dtype=np.float32)
+
+ tl_regrs = np.zeros((self.max_tag_len, 2), dtype=np.float32)
+ br_regrs = np.zeros((self.max_tag_len, 2), dtype=np.float32)
+ tl_tags = np.zeros((self.max_tag_len), dtype=np.int64)
+ br_tags = np.zeros((self.max_tag_len), dtype=np.int64)
+ tag_masks = np.zeros((self.max_tag_len), dtype=np.uint8)
+ tag_lens = np.zeros((), dtype=np.int32)
+ tag_nums = np.zeros((1), dtype=np.int32)
+
+ gt_bbox = sample['gt_bbox']
+ gt_class = sample['gt_class']
+ keep_inds = ((gt_bbox[:, 2] - gt_bbox[:, 0]) > 0) & \
+ ((gt_bbox[:, 3] - gt_bbox[:, 1]) > 0)
+ gt_bbox = gt_bbox[keep_inds]
+ gt_class = gt_class[keep_inds]
+ sample['gt_bbox'] = gt_bbox
+ sample['gt_class'] = gt_class
+ width_ratio = self.output_size[1] / sample['w']
+ height_ratio = self.output_size[0] / sample['h']
+ for i in range(gt_bbox.shape[0]):
+ width = gt_bbox[i][2] - gt_bbox[i][0]
+ height = gt_bbox[i][3] - gt_bbox[i][1]
+
+ xtl, ytl = gt_bbox[i][0], gt_bbox[i][1]
+ xbr, ybr = gt_bbox[i][2], gt_bbox[i][3]
+
+ fxtl = (xtl * width_ratio)
+ fytl = (ytl * height_ratio)
+ fxbr = (xbr * width_ratio)
+ fybr = (ybr * height_ratio)
+
+ xtl = int(fxtl)
+ ytl = int(fytl)
+ xbr = int(fxbr)
+ ybr = int(fybr)
+ if self.gaussian_bump:
+ width = math.ceil(width * width_ratio)
+ height = math.ceil(height * height_ratio)
+ if self.gaussian_rad == -1:
+ radius = gaussian_radius((height, width), self.gaussian_iou)
+ radius = max(0, int(radius))
+ else:
+ radius = self.gaussian_rad
+ draw_gaussian(tl_heatmaps[gt_class[i][0]], [xtl, ytl], radius)
+ draw_gaussian(br_heatmaps[gt_class[i][0]], [xbr, ybr], radius)
+ else:
+ tl_heatmaps[gt_class[i][0], ytl, xtl] = 1
+ br_heatmaps[gt_class[i][0], ybr, xbr] = 1
+
+ tl_regrs[i, :] = [fxtl - xtl, fytl - ytl]
+ br_regrs[i, :] = [fxbr - xbr, fybr - ybr]
+ tl_tags[tag_lens] = ytl * self.output_size[1] + xtl
+ br_tags[tag_lens] = ybr * self.output_size[1] + xbr
+ tag_lens += 1
+
+ tag_masks[:tag_lens] = 1
+
+ sample['tl_heatmaps'] = tl_heatmaps
+ sample['br_heatmaps'] = br_heatmaps
+ sample['tl_regrs'] = tl_regrs
+ sample['br_regrs'] = br_regrs
+ sample['tl_tags'] = tl_tags
+ sample['br_tags'] = br_tags
+ sample['tag_masks'] = tag_masks
+
+ return sample
+
+
+@register_op
+class CornerCrop(BaseOperator):
+ """
+ Random crop for CornerNet
+ Args:
+        random_scales (list): candidate ratios of the crop size to input_size.
+        border (int): border of the crop center
+        is_train (bool): whether the op is used for training
+ input_size (int): size of input image
+ """
+
+ def __init__(self,
+ random_scales=[0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3],
+ border=128,
+ is_train=True,
+ input_size=511):
+ super(CornerCrop, self).__init__()
+ self.random_scales = random_scales
+ self.border = border
+ self.is_train = is_train
+ self.input_size = input_size
+
+ def __call__(self, sample, context=None):
+ im_h, im_w = int(sample['h']), int(sample['w'])
+ if self.is_train:
+ scale = np.random.choice(self.random_scales)
+ height = int(self.input_size * scale)
+ width = int(self.input_size * scale)
+
+ w_border = self._get_border(self.border, im_w)
+ h_border = self._get_border(self.border, im_h)
+
+ ctx = np.random.randint(low=w_border, high=im_w - w_border)
+ cty = np.random.randint(low=h_border, high=im_h - h_border)
+
+ else:
+ cty, ctx = im_h // 2, im_w // 2
+ height = im_h | 127
+ width = im_w | 127
+
+ cropped_image = np.zeros(
+ (height, width, 3), dtype=sample['image'].dtype)
+
+ x0, x1 = max(ctx - width // 2, 0), min(ctx + width // 2, im_w)
+ y0, y1 = max(cty - height // 2, 0), min(cty + height // 2, im_h)
+
+ left_w, right_w = ctx - x0, x1 - ctx
+ top_h, bottom_h = cty - y0, y1 - cty
+
+ # crop image
+ cropped_ctx, cropped_cty = width // 2, height // 2
+ x_slice = slice(int(cropped_ctx - left_w), int(cropped_ctx + right_w))
+ y_slice = slice(int(cropped_cty - top_h), int(cropped_cty + bottom_h))
+ cropped_image[y_slice, x_slice, :] = sample['image'][y0:y1, x0:x1, :]
+
+ sample['image'] = cropped_image
+ sample['h'], sample['w'] = height, width
+
+ if self.is_train:
+ # crop detections
+ gt_bbox = sample['gt_bbox']
+ gt_bbox[:, 0:4:2] -= x0
+ gt_bbox[:, 1:4:2] -= y0
+ gt_bbox[:, 0:4:2] += cropped_ctx - left_w
+ gt_bbox[:, 1:4:2] += cropped_cty - top_h
+ else:
+ sample['borders'] = np.array(
+ [
+ cropped_cty - top_h, cropped_cty + bottom_h,
+ cropped_ctx - left_w, cropped_ctx + right_w
+ ],
+ dtype=np.float32)
+
+ return sample
+
+ def _get_border(self, border, size):
+ i = 1
+ while size - border // i <= border // i:
+ i *= 2
+ return border // i
+
+
+@register_op
+class CornerRatio(BaseOperator):
+ """
+ Ratio of output size to image size
+ Args:
+        input_size (int): the size of the input image
+        output_size (int): the size of the output heatmap
+ """
+
+ def __init__(self, input_size=511, output_size=64):
+ super(CornerRatio, self).__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+
+ def __call__(self, sample, context=None):
+ scale = (self.input_size + 1) // self.output_size
+ out_height, out_width = (sample['h'] + 1) // scale, (
+ sample['w'] + 1) // scale
+ height_ratio = out_height / float(sample['h'])
+ width_ratio = out_width / float(sample['w'])
+ sample['ratios'] = np.array([height_ratio, width_ratio])
+
+ return sample
diff --git a/ppdet/ext_op/README.md b/ppdet/ext_op/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc6173b36c3d70d4e30706da12b32523dcc6e00c
--- /dev/null
+++ b/ppdet/ext_op/README.md
@@ -0,0 +1,63 @@
+# Building the Custom OPs
+
+**Note:** The gcc version used to build the custom OPs must match the gcc version used to build Paddle. The Paddle develop nightly build is currently compiled with **gcc 4.8.2**; if you use the nightly build, please build the custom OPs with **gcc 4.8.2** as well, otherwise compatibility problems may occur.
+
+## Code Structure
+
+    - src: C++/CUDA source code of the extension OPs
+    - cornerpool_lib.py: Python API wrappers
+    - tests: unit tests for each OP
+
+
+## Building the Custom OPs
+
+The custom OPs require compiling the C++/CUDA implementation into a shared library. `src/make.sh` builds it with g++/nvcc; you may of course write your own Makefile or CMake configuration instead.
+
+The build needs PaddlePaddle's header files and links against PaddlePaddle's libraries. The header and library directories can be obtained with:
+
+```
+# python
+>>> import paddle
+>>> print(paddle.sysconfig.get_include())
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/include
+>>> print(paddle.sysconfig.get_lib())
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/libs
+```
+
+We provide the following build script for the shared library:
+
+```
+cd src
+sh make.sh
+```
+
+The build produces `cornerpool_lib.so`.
+
+**Note:** If PaddlePaddle was built from source and `WITH_MKLDNN` was not set during `cmake`, building the custom OPs will fail with errors such as `mkldnn.h` not found; in that case remove the `-DPADDLE_WITH_MKLDNN` option from the compile commands in `make.sh`.
+
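+After the library is built, the Python wrappers in `cornerpool_lib.py` load it via `fluid.load_op_library` on import. A minimal usage sketch (assuming `ppdet/ext_op` is on `PYTHONPATH` and `src/cornerpool_lib.so` has been built):
+
+```
+import numpy as np
+import paddle.fluid as fluid
+import cornerpool_lib  # loads src/cornerpool_lib.so on import
+
+x = fluid.data(name='x', shape=[2, 64, 10, 10], dtype='float32')
+# is_test=True dispatches to the custom CUDA op; otherwise a pure
+# fluid.layers fallback is used
+y = cornerpool_lib.bottom_pool(x, is_test=True)
+
+exe = fluid.Executor(fluid.CUDAPlace(0))
+exe.run(fluid.default_startup_program())
+out, = exe.run(feed={'x': np.random.rand(2, 64, 10, 10).astype('float32')},
+               fetch_list=[y])
+print(out.shape)  # (2, 64, 10, 10)
+```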
+
+## Running the Unit Tests
+
+Run the following unit tests to make sure the custom operators work correctly inside a network:
+
+```
+# go back to the ext_op directory and add it to PYTHONPATH
+cd ..
+export PYTHONPATH=$PYTHONPATH:`pwd`
+
+# run the unit tests
+python test/test_corner_op.py
+```
+
+If the tests pass, you will see output like the following:
+
+```
+.
+----------------------------------------------------------------------
+Ran 4 tests in 2.858s
+
+OK
+```
+
+For more information on defining C++ OPs outside the framework, see the [official documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/index_cn.html).
diff --git a/ppdet/ext_op/__init__.py b/ppdet/ext_op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d38f757f866ce724d4e079484e45797a9e3bdb9
--- /dev/null
+++ b/ppdet/ext_op/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from . import cornerpool_lib
+from .cornerpool_lib import *
+
+__all__ = cornerpool_lib.__all__
diff --git a/ppdet/ext_op/cornerpool_lib.py b/ppdet/ext_op/cornerpool_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56fc661ad27f89df7540f05bdbd01035e06ec67
--- /dev/null
+++ b/ppdet/ext_op/cornerpool_lib.py
@@ -0,0 +1,189 @@
+import os
+import paddle.fluid as fluid
+
+file_dir = os.path.dirname(os.path.abspath(__file__))
+fluid.load_op_library(os.path.join(file_dir, 'src/cornerpool_lib.so'))
+
+from paddle.fluid.layer_helper import LayerHelper
+
+__all__ = [
+ 'bottom_pool',
+ 'top_pool',
+ 'right_pool',
+ 'left_pool',
+]
+
+
+def bottom_pool(input, is_test=False, name=None):
+ """
+ This layer calculates the bottom pooling output based on the input.
+    Scan the input from top to bottom for the vertical max-pooling.
+ The output has the same shape with input.
+ Args:
+ input(Variable): This input is a Tensor with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Returns:
+ Variable(Tensor): The output of bottom_pool, with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Examples:
+        .. code-block:: python
+            import paddle.fluid as fluid
+            import cornerpool_lib
+            input = fluid.data(
+                name='input', shape=[2, 64, 10, 10], dtype='float32')
+            output = cornerpool_lib.bottom_pool(input)
+ """
+ if is_test:
+ helper = LayerHelper('bottom_pool', **locals())
+ dtype = helper.input_dtype()
+ output = helper.create_variable_for_type_inference(dtype)
+ max_map = helper.create_variable_for_type_inference(dtype)
+ helper.append_op(
+ type="bottom_pool",
+ inputs={"X": input},
+ outputs={"Output": output,
+ "MaxMap": max_map})
+ return output
+ H = input.shape[2]
+ i = 1
+ output = input
+ while i < H:
+ cur = output[:, :, i:, :]
+ next = output[:, :, :H - i, :]
+ max_v = fluid.layers.elementwise_max(cur, next)
+ output = fluid.layers.concat([output[:, :, :i, :], max_v], axis=2)
+ i *= 2
+
+ return output
+
+
+def top_pool(input, is_test=False, name=None):
+ """
+ This layer calculates the top pooling output based on the input.
+ Scan the input from bottom to top for the vertical max-pooling.
+ The output has the same shape with input.
+ Args:
+ input(Variable): This input is a Tensor with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Returns:
+ Variable(Tensor): The output of top_pool, with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Examples:
+        .. code-block:: python
+            import paddle.fluid as fluid
+            import cornerpool_lib
+            input = fluid.data(
+                name='input', shape=[2, 64, 10, 10], dtype='float32')
+            output = cornerpool_lib.top_pool(input)
+ """
+ if is_test:
+ helper = LayerHelper('top_pool', **locals())
+ dtype = helper.input_dtype()
+ output = helper.create_variable_for_type_inference(dtype)
+ max_map = helper.create_variable_for_type_inference(dtype)
+ helper.append_op(
+ type="top_pool",
+ inputs={"X": input},
+ outputs={"Output": output,
+ "MaxMap": max_map})
+ return output
+
+ H = input.shape[2]
+ i = 1
+ output = input
+ while i < H:
+ cur = output[:, :, :H - i, :]
+ next = output[:, :, i:, :]
+ max_v = fluid.layers.elementwise_max(cur, next)
+ output = fluid.layers.concat([max_v, output[:, :, H - i:, :]], axis=2)
+ i *= 2
+
+ return output
+
+
+def right_pool(input, is_test=False, name=None):
+ """
+ This layer calculates the right pooling output based on the input.
+ Scan the input from left to right for the horizontal max-pooling.
+ The output has the same shape with input.
+ Args:
+ input(Variable): This input is a Tensor with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Returns:
+ Variable(Tensor): The output of right_pool, with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Examples:
+        .. code-block:: python
+            import paddle.fluid as fluid
+            import cornerpool_lib
+            input = fluid.data(
+                name='input', shape=[2, 64, 10, 10], dtype='float32')
+            output = cornerpool_lib.right_pool(input)
+ """
+ if is_test:
+ helper = LayerHelper('right_pool', **locals())
+ dtype = helper.input_dtype()
+ output = helper.create_variable_for_type_inference(dtype)
+ max_map = helper.create_variable_for_type_inference(dtype)
+ helper.append_op(
+ type="right_pool",
+ inputs={"X": input},
+ outputs={"Output": output,
+ "MaxMap": max_map})
+ return output
+
+ W = input.shape[3]
+ i = 1
+ output = input
+ while i < W:
+ cur = output[:, :, :, i:]
+ next = output[:, :, :, :W - i]
+ max_v = fluid.layers.elementwise_max(cur, next)
+ output = fluid.layers.concat([output[:, :, :, :i], max_v], axis=-1)
+ i *= 2
+
+ return output
+
+
+def left_pool(input, is_test=False, name=None):
+ """
+ This layer calculates the left pooling output based on the input.
+ Scan the input from right to left for the horizontal max-pooling.
+ The output has the same shape with input.
+ Args:
+ input(Variable): This input is a Tensor with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Returns:
+ Variable(Tensor): The output of left_pool, with shape [N, C, H, W].
+ The data type is float32 or float64.
+ Examples:
+        .. code-block:: python
+            import paddle.fluid as fluid
+            import cornerpool_lib
+            input = fluid.data(
+                name='input', shape=[2, 64, 10, 10], dtype='float32')
+            output = cornerpool_lib.left_pool(input)
+ """
+ if is_test:
+ helper = LayerHelper('left_pool', **locals())
+ dtype = helper.input_dtype()
+ output = helper.create_variable_for_type_inference(dtype)
+ max_map = helper.create_variable_for_type_inference(dtype)
+ helper.append_op(
+ type="left_pool",
+ inputs={"X": input},
+ outputs={"Output": output,
+ "MaxMap": max_map})
+ return output
+
+ W = input.shape[3]
+ i = 1
+ output = input
+ while i < W:
+ cur = output[:, :, :, :W - i]
+ next = output[:, :, :, i:]
+ max_v = fluid.layers.elementwise_max(cur, next)
+ output = fluid.layers.concat([max_v, output[:, :, :, W - i:]], axis=-1)
+ i *= 2
+
+ return output
diff --git a/ppdet/ext_op/src/bottom_pool_op.cc b/ppdet/ext_op/src/bottom_pool_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6a867d1f127a34a9f93a0d0719a5dba039466f24
--- /dev/null
+++ b/ppdet/ext_op/src/bottom_pool_op.cc
@@ -0,0 +1,101 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class BottomPoolOp : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ ctx->ShareDim("X", /*->*/ "MaxMap");
+ ctx->ShareDim("X", /*->*/ "Output");
+ }
+
+protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+ ctx.GetPlace());
+ }
+};
+
+class BottomPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+ void Make() override {
+ AddInput("X",
+ "Input with shape (batch, C, H, W)");
+ AddOutput("MaxMap", "Max map with index of maximum value of input");
+ AddOutput("Output", "output with same shape as input(X)");
+ AddComment(
+ R"Doc(
+This operation calculates the bottom pooling output based on the input.
+Scan the input from top to bottom for the vertical max-pooling.
+The output has the same shape with input.
+ )Doc");
+ }
+};
+
+class BottomPoolOpGrad : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+protected:
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
+ "Input(Output@GRAD) should not be null");
+ auto out_grad_name = framework::GradVarName("Output");
+ ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
+ }
+
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Output"))->type(),
+ ctx.GetPlace());
+ }
+};
+
+template <typename T>
+class BottomPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
+public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+protected:
+  void Apply(GradOpPtr<T> op) const override {
+ op->SetType("bottom_pool_grad");
+ op->SetInput("X", this->Input("X"));
+ op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+ op->SetInput("MaxMap", this->Output("MaxMap"));
+ op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+ op->SetAttrMap(this->Attrs());
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(bottom_pool,
+ ops::BottomPoolOp,
+ ops::BottomPoolOpMaker,
+                  ops::BottomPoolGradDescMaker<paddle::framework::OpDesc>,
+                  ops::BottomPoolGradDescMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(bottom_pool_grad, ops::BottomPoolOpGrad);
diff --git a/ppdet/ext_op/src/bottom_pool_op.cu b/ppdet/ext_op/src/bottom_pool_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4912ec3c0effb2d924111203168da821aae16b19
--- /dev/null
+++ b/ppdet/ext_op/src/bottom_pool_op.cu
@@ -0,0 +1,104 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/memory/memory.h"
+#include <vector>
+#include "util.cu.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+ return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+ kNumMaximumNumBlocks);
+}
+
+template <typename T>
+class BottomPoolOpCUDAKernel : public framework::OpKernel<T> {
+public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "This kernel only runs on GPU device.");
+    auto *x = ctx.Input<Tensor>("X");
+    auto *max_map = ctx.Output<Tensor>("MaxMap");
+    auto *output = ctx.Output<Tensor>("Output");
+    auto *x_data = x->data<T>();
+ auto x_dims = x->dims();
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int num = x->numel();
+ auto& dev_ctx = ctx.cuda_device_context();
+
+    int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
+    T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+    auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int blocks = NumBlocks(num / height);
+
+ auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T));
+    T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
+    auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int));
+    int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
+
+    GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 2, false, max_val_data, max_ind_data, max_map_data);
+
+    blocks = NumBlocks(num);
+    ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 2, output_data);
+ }
+};
+
+template <typename T>
+class BottomPoolGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* max_map = ctx.Input<Tensor>("MaxMap");
+    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+ auto x_dims = x->dims();
+
+ auto& dev_ctx = ctx.cuda_device_context();
+    T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+    auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int grad_num = in_grad->numel();
+ int blocks = NumBlocks(grad_num);
+    FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
+
+    ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 2, in_grad_data);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(bottom_pool,
+                        ops::BottomPoolOpCUDAKernel<float>,
+                        ops::BottomPoolOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(bottom_pool_grad,
+                        ops::BottomPoolGradOpCUDAKernel<float>,
+                        ops::BottomPoolGradOpCUDAKernel<double>);
diff --git a/ppdet/ext_op/src/left_pool_op.cc b/ppdet/ext_op/src/left_pool_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c2a8f169fe0f11b7cebfdb16fb38e111a820df48
--- /dev/null
+++ b/ppdet/ext_op/src/left_pool_op.cc
@@ -0,0 +1,101 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class LeftPoolOp : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ ctx->ShareDim("X", /*->*/ "MaxMap");
+ ctx->ShareDim("X", /*->*/ "Output");
+ }
+
+protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+ ctx.GetPlace());
+ }
+};
+
+class LeftPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+ void Make() override {
+ AddInput("X",
+ "Input with shape (batch, C, H, W)");
+ AddOutput("MaxMap", "Max map with index of maximum value of input");
+ AddOutput("Output", "output with same shape as input(X)");
+ AddComment(
+ R"Doc(
+This operation calculates the left pooling output based on the input.
+Scan the input from right to left for the horizontal max-pooling.
+The output has the same shape with input.
+ )Doc");
+ }
+};
+
+class LeftPoolOpGrad : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+protected:
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
+ "Input(Output@GRAD) should not be null");
+ auto out_grad_name = framework::GradVarName("Output");
+ ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
+ }
+
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Output"))->type(),
+ ctx.GetPlace());
+ }
+};
+
+template <typename T>
+class LeftPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
+public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+protected:
+  void Apply(GradOpPtr<T> op) const override {
+ op->SetType("left_pool_grad");
+ op->SetInput("X", this->Input("X"));
+ op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+ op->SetInput("MaxMap", this->Output("MaxMap"));
+ op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+ op->SetAttrMap(this->Attrs());
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(left_pool,
+ ops::LeftPoolOp,
+ ops::LeftPoolOpMaker,
+                  ops::LeftPoolGradDescMaker<paddle::framework::OpDesc>,
+                  ops::LeftPoolGradDescMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(left_pool_grad, ops::LeftPoolOpGrad);
diff --git a/ppdet/ext_op/src/left_pool_op.cu b/ppdet/ext_op/src/left_pool_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a5e9323adc6e268bf6572cf5470dec841664b6ec
--- /dev/null
+++ b/ppdet/ext_op/src/left_pool_op.cu
@@ -0,0 +1,106 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/memory/memory.h"
+#include <vector>
+#include "util.cu.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+ return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+ kNumMaximumNumBlocks);
+}
+
+template <typename T>
+class LeftPoolOpCUDAKernel : public framework::OpKernel<T> {
+public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "This kernel only runs on GPU device.");
+    auto *x = ctx.Input<Tensor>("X");
+    auto *max_map = ctx.Output<Tensor>("MaxMap");
+    auto *output = ctx.Output<Tensor>("Output");
+    auto *x_data = x->data<T>();
+ auto x_dims = x->dims();
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int num = x->numel();
+ auto& dev_ctx = ctx.cuda_device_context();
+
+    int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
+    T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+    auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int blocks = NumBlocks(num / width);
+
+ auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T));
+    T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
+    auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int));
+    int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
+
+    GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 3, true, max_val_data, max_ind_data, max_map_data);
+
+    blocks = NumBlocks(num);
+    ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 3, output_data);
+
+ }
+};
+
+template <typename T>
+class LeftPoolGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* max_map = ctx.Input<Tensor>("MaxMap");
+    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+ auto x_dims = x->dims();
+
+ auto& dev_ctx = ctx.cuda_device_context();
+    T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+    auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int grad_num = in_grad->numel();
+ int blocks = NumBlocks(grad_num);
+    FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
+
+    ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 3, in_grad_data);
+ }
+};
+
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(left_pool,
+                        ops::LeftPoolOpCUDAKernel<float>,
+                        ops::LeftPoolOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(left_pool_grad,
+                        ops::LeftPoolGradOpCUDAKernel<float>,
+                        ops::LeftPoolGradOpCUDAKernel<double>);
diff --git a/ppdet/ext_op/src/make.sh b/ppdet/ext_op/src/make.sh
new file mode 100755
index 0000000000000000000000000000000000000000..bd0d3a3c904bd9409d200100de6d8e30ca974504
--- /dev/null
+++ b/ppdet/ext_op/src/make.sh
@@ -0,0 +1,23 @@
+include_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_include())' )
+lib_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_lib())' )
+
+echo $include_dir
+echo $lib_dir
+
+OPS='bottom_pool_op top_pool_op right_pool_op left_pool_op'
+for op in ${OPS}
+do
+nvcc ${op}.cu -c -o ${op}.cu.o -ccbin cc -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -DPADDLE_WITH_MKLDNN -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O0 -g -DNVCC \
+ -I ${include_dir}/third_party/ \
+ -I ${include_dir}
+done
+
+g++ bottom_pool_op.cc bottom_pool_op.cu.o top_pool_op.cc top_pool_op.cu.o right_pool_op.cc right_pool_op.cu.o left_pool_op.cc left_pool_op.cu.o -o cornerpool_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \
+ -I ${include_dir}/third_party/ \
+ -I ${include_dir} \
+ -L ${lib_dir} \
+ -L /usr/local/cuda/lib64 -lpaddle_framework -lcudart
+
+rm *.cu.o
+
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
diff --git a/ppdet/ext_op/src/right_pool_op.cc b/ppdet/ext_op/src/right_pool_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6bf74a1b08878724e388ae66ffe34c1ae0c65ec8
--- /dev/null
+++ b/ppdet/ext_op/src/right_pool_op.cc
@@ -0,0 +1,101 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class RightPoolOp : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ ctx->ShareDim("X", /*->*/ "MaxMap");
+ ctx->ShareDim("X", /*->*/ "Output");
+ }
+
+protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+ ctx.GetPlace());
+ }
+};
+
+class RightPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+ void Make() override {
+ AddInput("X",
+ "Input with shape (batch, C, H, W)");
+ AddOutput("MaxMap", "Max map with index of maximum value of input");
+ AddOutput("Output", "output with same shape as input(X)");
+ AddComment(
+ R"Doc(
+This operatio calculates the right pooling output based on the input.
+Scan the input from left to right or the horizontal max-pooling.
+The output has the same shape with input.
+ )Doc");
+ }
+};
+
+class RightPoolOpGrad : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+protected:
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
+ "Input(Output@GRAD) should not be null");
+ auto out_grad_name = framework::GradVarName("Output");
+ ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
+ }
+
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+ ctx.Input<Tensor>(framework::GradVarName("Output"))->type(),
+ ctx.GetPlace());
+ }
+};
+
+template <typename T>
+class RightPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
+public:
+ using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+protected:
+ void Apply(GradOpPtr<T> op) const override {
+ op->SetType("right_pool_grad");
+ op->SetInput("X", this->Input("X"));
+ op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+ op->SetInput("MaxMap", this->Output("MaxMap"));
+ op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+ op->SetAttrMap(this->Attrs());
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(right_pool,
+ ops::RightPoolOp,
+ ops::RightPoolOpMaker,
+ ops::RightPoolGradDescMaker<paddle::framework::OpDesc>,
+ ops::RightPoolGradDescMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(right_pool_grad, ops::RightPoolOpGrad);
diff --git a/ppdet/ext_op/src/right_pool_op.cu b/ppdet/ext_op/src/right_pool_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..08a52ecf1eec9816c8a29e50452344d522b75c49
--- /dev/null
+++ b/ppdet/ext_op/src/right_pool_op.cu
@@ -0,0 +1,105 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/memory/memory.h"
+#include <vector>
+#include "util.cu.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+ return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+ kNumMaximumNumBlocks);
+}
+
+template <typename T>
+class RightPoolOpCUDAKernel : public framework::OpKernel<T> {
+public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "This kernel only runs on GPU device.");
+ auto *x = ctx.Input<Tensor>("X");
+ auto *max_map = ctx.Output<Tensor>("MaxMap");
+ auto *output = ctx.Output<Tensor>("Output");
+ auto *x_data = x->data<T>();
+ auto x_dims = x->dims();
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int num = x->numel();
+ auto& dev_ctx = ctx.cuda_device_context();
+
+ int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
+ T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+ auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int blocks = NumBlocks(num / width);
+
+ auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T));
+ T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
+ auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int));
+ int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
+
+ GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 3, false, max_val_data, max_ind_data, max_map_data);
+
+ blocks = NumBlocks(num);
+ ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 3, output_data);
+
+ }
+};
+
+template <typename T>
+class RightPoolGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ auto* x = ctx.Input<Tensor>("X");
+ auto* max_map = ctx.Input<Tensor>("MaxMap");
+ auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+ auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+ auto x_dims = x->dims();
+
+ auto& dev_ctx = ctx.cuda_device_context();
+ T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+ auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int grad_num = in_grad->numel();
+ int blocks = NumBlocks(grad_num);
+ FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
+
+ ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 3, in_grad_data);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(right_pool,
+ ops::RightPoolOpCUDAKernel<float>,
+ ops::RightPoolOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(right_pool_grad,
+ ops::RightPoolGradOpCUDAKernel<float>,
+ ops::RightPoolGradOpCUDAKernel<double>);
diff --git a/ppdet/ext_op/src/top_pool_op.cc b/ppdet/ext_op/src/top_pool_op.cc
new file mode 100755
index 0000000000000000000000000000000000000000..29cba6660193c3bfe861dc81b03edbbb9368ea19
--- /dev/null
+++ b/ppdet/ext_op/src/top_pool_op.cc
@@ -0,0 +1,102 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class TopPoolOp : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ ctx->ShareDim("X", /*->*/ "MaxMap");
+ ctx->ShareDim("X", /*->*/ "Output");
+ }
+
+protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+ ctx.GetPlace());
+ }
+};
+
+class TopPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+ void Make() override {
+ AddInput("X",
+ "Input with shape (batch, C, H, W)");
+ AddOutput("MaxMap", "Max map with index of maximum value of input");
+ AddOutput("Output", "Output with same shape as input(X)");
+ AddComment(
+ R"Doc(
+This operatio calculates the top pooling output based on the input.
+Scan the input from bottom to top for the vertical max-pooling.
+The output has the same shape with input.
+ )Doc");
+ }
+};
+
+class TopPoolOpGrad : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+protected:
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
+ "Input(Output@GRAD) should not be null");
+
+ auto out_grad_name = framework::GradVarName("Output");
+ ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
+ }
+
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+ ctx.Input<Tensor>(framework::GradVarName("Output"))->type(),
+ ctx.GetPlace());
+ }
+};
+
+template <typename T>
+class TopPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
+ public:
+ using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+ void Apply(GradOpPtr<T> op) const override {
+ op->SetType("top_pool_grad");
+ op->SetInput("X", this->Input("X"));
+ op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+ op->SetInput("MaxMap", this->Output("MaxMap"));
+ op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+ op->SetAttrMap(this->Attrs());
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(top_pool,
+ ops::TopPoolOp,
+ ops::TopPoolOpMaker,
+ ops::TopPoolGradDescMaker<paddle::framework::OpDesc>,
+ ops::TopPoolGradDescMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(top_pool_grad, ops::TopPoolOpGrad);
diff --git a/ppdet/ext_op/src/top_pool_op.cu b/ppdet/ext_op/src/top_pool_op.cu
new file mode 100755
index 0000000000000000000000000000000000000000..f6237fe798098ad0c6cb25597de849e65383ab78
--- /dev/null
+++ b/ppdet/ext_op/src/top_pool_op.cu
@@ -0,0 +1,104 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/memory/memory.h"
+#include <vector>
+#include "util.cu.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+ return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+ kNumMaximumNumBlocks);
+}
+
+template <typename T>
+class TopPoolOpCUDAKernel : public framework::OpKernel<T> {
+public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "This kernel only runs on GPU device.");
+ auto *x = ctx.Input<Tensor>("X");
+ auto *max_map = ctx.Output<Tensor>("MaxMap");
+ auto *output = ctx.Output<Tensor>("Output");
+ auto *x_data = x->data<T>();
+ auto x_dims = x->dims();
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int num = x->numel();
+ auto& dev_ctx = ctx.cuda_device_context();
+
+ int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
+ T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+ auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int blocks = NumBlocks(num / height);
+
+ auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T));
+ T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
+ auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int));
+ int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
+
+ GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 2, true, max_val_data, max_ind_data, max_map_data);
+
+ blocks = NumBlocks(num);
+ ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 2, output_data);
+ }
+};
+
+template <typename T>
+class TopPoolGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ auto* x = ctx.Input<Tensor>("X");
+ auto* max_map = ctx.Input<Tensor>("MaxMap");
+ auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+ auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+ auto x_dims = x->dims();
+ auto& dev_ctx = ctx.cuda_device_context();
+ T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
+ auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
+
+ int threads = kNumCUDAThreads;
+ int NC_num = x_dims[0] * x_dims[1];
+ int height = x_dims[2];
+ int width = x_dims[3];
+ int grad_num = in_grad->numel();
+ int blocks = NumBlocks(grad_num);
+ FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
+
+ ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 2, in_grad_data);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(top_pool,
+ ops::TopPoolOpCUDAKernel<float>,
+ ops::TopPoolOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(top_pool_grad,
+ ops::TopPoolGradOpCUDAKernel<float>,
+ ops::TopPoolGradOpCUDAKernel<double>);
diff --git a/ppdet/ext_op/src/util.cu.h b/ppdet/ext_op/src/util.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..615e45a7891fca7f89eea5824b4691d33a1fe5ef
--- /dev/null
+++ b/ppdet/ext_op/src/util.cu.h
@@ -0,0 +1,223 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/memory/memory.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+template <typename T>
+__global__ void FillConstant(T* x, int num, int fill_num) {
+ CUDA_1D_KERNEL_LOOP(i, fill_num) {
+ x[i] = static_cast<T>(num);
+ }
+}
+
+template <typename T>
+__global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int W,
+ const int axis, const int start, const int end,
+ T* output) {
+ int HW_num = H * W;
+ int length = axis == 2 ? W : H;
+ int sliced_len = end - start;
+ int cur_HW_num = length * sliced_len;
+ // slice input on H or W (axis is 2 or 3)
+ CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) {
+ int NC_id = i / cur_HW_num;
+ int HW_id = i % cur_HW_num;
+ if (axis == 2){
+ output[i] = x[NC_id * HW_num + start * W + HW_id];
+ } else if (axis == 3) {
+ int col = HW_id % sliced_len;
+ int row = HW_id / sliced_len;
+ output[i] = x[NC_id * HW_num + row * W + start + col];
+ }
+ }
+}
+
+template <typename T>
+__global__ void MaxOut(const T* input, const int next_ind, const int NC_num,
+ const int H, const int W, const int axis,
+ const int start, const int end, T* output) {
+ int HW_num = H * W;
+ int length = axis == 2 ? W : H;
+ T cur = static_cast<T>(0.);
+ T next = static_cast<T>(0.);
+ T max_v = static_cast<T>(0.);
+ int sliced_len = end - start;
+ int cur_HW_num = length * sliced_len;
+ // compare cur and next and assign max values to output
+ CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) {
+ int NC_id = i / cur_HW_num;
+ int HW_id = i % cur_HW_num;
+
+ if (axis == 2){
+ cur = input[NC_id * HW_num + start * W + HW_id];
+ next = input[NC_id * HW_num + next_ind * W + HW_id];
+ max_v = cur > next ? cur : next;
+ output[NC_id * HW_num + start * W + HW_id] = max_v;
+ } else if (axis == 3) {
+ int col = HW_id % sliced_len;
+ int row = HW_id / sliced_len;
+ cur = input[NC_id * HW_num + row * W + start + col];
+ next = input[NC_id * HW_num + row * W + next_ind + col];
+ max_v = cur > next ? cur : next;
+ output[NC_id * HW_num + row * W + start + col] = max_v;
+ }
+ __syncthreads();
+ }
+}
+
+template <typename T>
+__global__ void UpdateMaxInfo(const T* input, const int NC_num,
+ const int H, const int W, const int axis,
+ const int index, T* max_val, int* max_ind) {
+ int length = axis == 2 ? W : H;
+ int HW_num = H * W;
+ T val = static_cast<T>(0.);
+ CUDA_1D_KERNEL_LOOP(i, NC_num * length) {
+ int NC_id = i / length;
+ int length_id = i % length;
+ if (axis == 2) {
+ val = input[NC_id * HW_num + index * W + length_id];
+ } else if (axis == 3) {
+ val = input[NC_id * HW_num + length_id * W + index];
+ }
+ if (val > max_val[i]) {
+ max_val[i] = val;
+ max_ind[i] = index;
+ }
+ __syncthreads();
+ }
+}
+
+template <typename T>
+__global__ void ScatterAddOnAxis(const T* input, const int start, const int* max_ind, const int NC_num, const int H, const int W, const int axis, T* output) {
+ int length = axis == 2 ? W : H;
+ int HW_num = H * W;
+ CUDA_1D_KERNEL_LOOP(i, NC_num * length) {
+ int NC_id = i / length;
+ int length_id = i % length;
+ int id_ = max_ind[i];
+ if (axis == 2) {
+ platform::CudaAtomicAdd(output + NC_id * HW_num + id_ * W + length_id, input[NC_id * HW_num + start * W + length_id]);
+ //output[NC_id * HW_num + id_ * W + length_id] += input[NC_id * HW_num + start * W + length_id];
+ } else if (axis == 3) {
+ platform::CudaAtomicAdd(output + NC_id * HW_num + length_id * W + id_, input[NC_id * HW_num + length_id * W + start]);
+ //output[NC_id * HW_num + length_id * W + id_] += input[NC_id * HW_num + length_id * W + start];
+ }
+ __syncthreads();
+ }
+}
+
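+// GetMaxInfo runs a sequential sweep along H (axis == 2) or W (axis == 3),
+// optionally in reverse order, keeping a running maximum value (max_val) and
+// its index (max_ind) per (N*C, len) slice. For every element it records in
+// max_map the index of the running maximum seen so far, which the forward and
+// backward kernels below use to gather values and scatter gradients.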
+template <typename T>
+__global__ void GetMaxInfo(const T* input, const int NC_num,
+ const int H, const int W, const int axis,
+ const bool reverse, T* max_val, int* max_ind,
+ int* max_map) {
+ int start = 0;
+ int end = axis == 2 ? H: W;
+ int s = reverse ? end-1 : start;
+ int e = reverse ? start-1 : end;
+ int step = reverse ? -1 : 1;
+ int len = axis == 2 ? W : H;
+ int loc = 0;
+ T val = static_cast<T>(0.);
+ for (int i = s; ; ) {
+ if (i == s) {
+ CUDA_1D_KERNEL_LOOP(j, NC_num * len) {
+ int NC_id = j / len;
+ int len_id = j % len;
+ if (axis == 2) {
+ loc = NC_id * H * W + i * W + len_id;
+ } else if (axis == 3){
+ loc = NC_id * H * W + len_id * W + i;
+ }
+ max_ind[j] = i;
+ max_map[loc] = max_ind[j];
+ max_val[j] = input[loc];
+ __syncthreads();
+ }
+ } else {
+ CUDA_1D_KERNEL_LOOP(j, NC_num * len) {
+ int NC_id = j / len;
+ int len_id = j % len;
+
+ if (axis == 2) {
+ loc = NC_id * H * W + i * W + len_id;
+ } else if (axis == 3){
+ loc = NC_id * H * W + len_id * W + i;
+ }
+ val = input[loc];
+ T max_v = max_val[j];
+ if (val > max_v) {
+ max_val[j] = val;
+ max_map[loc] = i;
+ max_ind[j] = i;
+ } else {
+ max_map[loc] = max_ind[j];
+ }
+ __syncthreads();
+ }
+ }
+ i += step;
+ if (s < e && i >= e) break;
+ if (s > e && i <= e) break;
+ }
+}
+
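+// Forward pass: each output element copies the input value at the running-max
+// position recorded in max_map (a gather along the pooled axis).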
+template <typename T>
+__global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){
+ CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) {
+ int loc = max_map[i];
+ int NC_id = i / (H * W);
+ int len_id = 0;
+ if (axis == 2) {
+ len_id = i % W;
+ output[i] = input[NC_id * H * W + loc * W + len_id];
+ } else {
+ len_id = i % (H * W) / W;
+ output[i] = input[NC_id * H * W + len_id * W + loc];
+ }
+ }
+}
+
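+// Backward pass: scatter-add each output gradient to the input position that
+// produced the running max; atomic adds are needed because several outputs
+// can map to the same input element.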
+template <typename T>
+__global__ void ScatterAddBw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){
+ CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) {
+ int loc = max_map[i];
+ int NC_id = i / (H * W);
+ int len_id = 0;
+ int offset = 0;
+ if (axis == 2) {
+ len_id = i % W;
+ offset = NC_id * H * W + loc * W + len_id;
+ } else {
+ len_id = i % (H * W) / W;
+ offset = NC_id * H * W + len_id * W + loc;
+ }
+ platform::CudaAtomicAdd(output + offset, input[i]);
+ }
+}
+
+} // namespace operators
+} // namespace paddle
diff --git a/ppdet/ext_op/test/test_corner_pool.py b/ppdet/ext_op/test/test_corner_pool.py
new file mode 100755
index 0000000000000000000000000000000000000000..ee5e6b07d91e21b65733a3bce9cd705e95d99cd1
--- /dev/null
+++ b/ppdet/ext_op/test/test_corner_pool.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+import cornerpool_lib
+
+
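+# Numpy reference implementations of the corner pooling ops: each one
+# propagates the running maximum along H or W with shifted-slice comparisons
+# and is used as ground truth for the custom CUDA kernels.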
+def bottom_pool_np(x):
+ height = x.shape[2]
+ output = x.copy()
+ for ind in range(height):
+ cur = output[:, :, ind:height, :]
+ next = output[:, :, :height - ind, :]
+ output[:, :, ind:height, :] = np.maximum(cur, next)
+ return output
+
+
+def top_pool_np(x):
+ height = x.shape[2]
+ output = x.copy()
+ for ind in range(height):
+ cur = output[:, :, :height - ind, :]
+ next = output[:, :, ind:height, :]
+ output[:, :, :height - ind, :] = np.maximum(cur, next)
+ return output
+
+
+def right_pool_np(x):
+ width = x.shape[3]
+ output = x.copy()
+ for ind in range(width):
+ cur = output[:, :, :, ind:width]
+ next = output[:, :, :, :width - ind]
+ output[:, :, :, ind:width] = np.maximum(cur, next)
+ return output
+
+
+def left_pool_np(x):
+ width = x.shape[3]
+ output = x.copy()
+ for ind in range(width):
+ cur = output[:, :, :, :width - ind]
+ next = output[:, :, :, ind:width]
+ output[:, :, :, :width - ind] = np.maximum(cur, next)
+ return output
+
+
+class TestRightPoolOp(unittest.TestCase):
+ def funcmap(self):
+ self.func_map = {
+ 'bottom_x': [cornerpool_lib.bottom_pool, bottom_pool_np],
+ 'top_x': [cornerpool_lib.top_pool, top_pool_np],
+ 'right_x': [cornerpool_lib.right_pool, right_pool_np],
+ 'left_x': [cornerpool_lib.left_pool, left_pool_np]
+ }
+
+ def setup(self):
+ self.name = 'right_x'
+
+ def test_check_output(self):
+ self.funcmap()
+ self.setup()
+ x_shape = (2, 10, 16, 16)
+ x_type = "float64"
+
+ sp = fluid.Program()
+ tp = fluid.Program()
+ place = fluid.CUDAPlace(0)
+
+ with fluid.program_guard(tp, sp):
+ x = fluid.layers.data(
+ name=self.name,
+ shape=x_shape,
+ dtype=x_type,
+ append_batch_size=False)
+ y = self.func_map[self.name][0](x)
+
+ np.random.seed(0)
+ x_np = np.random.uniform(-1000, 1000, x_shape).astype(x_type)
+
+ out_np = self.func_map[self.name][1](x_np)
+
+ exe = fluid.Executor(place)
+ outs = exe.run(tp, feed={self.name: x_np}, fetch_list=[y])
+
+ self.assertTrue(np.allclose(outs, out_np))
+
+
+class TestTopPoolOp(TestRightPoolOp):
+ def setup(self):
+ self.name = 'top_x'
+
+
+class TestBottomPoolOp(TestRightPoolOp):
+ def setup(self):
+ self.name = 'bottom_x'
+
+
+class TestLeftPoolOp(TestRightPoolOp):
+ def setup(self):
+ self.name = 'left_x'
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/ppdet/modeling/anchor_heads/__init__.py b/ppdet/modeling/anchor_heads/__init__.py
index c6e495598fb7b7103e55f95f6acef00fdef98e00..49640bdf2c02d4e903aaa5ea5b6042ca52aa6623 100644
--- a/ppdet/modeling/anchor_heads/__init__.py
+++ b/ppdet/modeling/anchor_heads/__init__.py
@@ -18,8 +18,10 @@ from . import rpn_head
from . import yolo_head
from . import retina_head
from . import fcos_head
+from . import corner_head
from .rpn_head import *
from .yolo_head import *
from .retina_head import *
from .fcos_head import *
+from .corner_head import *
diff --git a/ppdet/modeling/anchor_heads/corner_head.py b/ppdet/modeling/anchor_heads/corner_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc7c64c19d742515a0edee6a6281e54c490b9ae3
--- /dev/null
+++ b/ppdet/modeling/anchor_heads/corner_head.py
@@ -0,0 +1,486 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import Constant
+
+from ..backbones.hourglass import _conv_norm, kaiming_init
+from ppdet.core.workspace import register
+import numpy as np
+try:
+ import cornerpool_lib
+except Exception:
+ print(
+ "warning: cornerpool_lib not found, compile it in ext_op first if needed"
+ )
+
+__all__ = ['CornerHead']
+
+
+def corner_output(x, pool1, pool2, dim, name=None):
+ p_conv1 = fluid.layers.conv2d(
+ pool1 + pool2,
+ filter_size=3,
+ num_filters=dim,
+ padding=1,
+ param_attr=ParamAttr(
+ name=name + "_p_conv1_weight",
+ initializer=kaiming_init(pool1 + pool2, 3)),
+ bias_attr=False,
+ name=name + '_p_conv1')
+ p_bn1 = fluid.layers.batch_norm(
+ p_conv1,
+ param_attr=ParamAttr(name=name + '_p_bn1_weight'),
+ bias_attr=ParamAttr(name=name + '_p_bn1_bias'),
+ moving_mean_name=name + '_p_bn1_running_mean',
+ moving_variance_name=name + '_p_bn1_running_var',
+ name=name + '_p_bn1')
+
+ conv1 = fluid.layers.conv2d(
+ x,
+ filter_size=1,
+ num_filters=dim,
+ param_attr=ParamAttr(
+ name=name + "_conv1_weight", initializer=kaiming_init(x, 1)),
+ bias_attr=False,
+ name=name + '_conv1')
+ bn1 = fluid.layers.batch_norm(
+ conv1,
+ param_attr=ParamAttr(name=name + '_bn1_weight'),
+ bias_attr=ParamAttr(name=name + '_bn1_bias'),
+ moving_mean_name=name + '_bn1_running_mean',
+ moving_variance_name=name + '_bn1_running_var',
+ name=name + '_bn1')
+
+ relu1 = fluid.layers.relu(p_bn1 + bn1)
+ conv2 = _conv_norm(
+ relu1, 3, dim, pad=1, bn_act='relu', name=name + '_conv2')
+ return conv2
+
+
+def corner_pool(x, dim, pool1, pool2, is_test=False, name=None):
+ p1_conv1 = _conv_norm(
+ x, 3, 128, pad=1, bn_act='relu', name=name + '_p1_conv1')
+ pool1 = pool1(p1_conv1, is_test=is_test, name=name + '_pool1')
+ p2_conv1 = _conv_norm(
+ x, 3, 128, pad=1, bn_act='relu', name=name + '_p2_conv1')
+ pool2 = pool2(p2_conv1, is_test=is_test, name=name + '_pool2')
+
+ conv2 = corner_output(x, pool1, pool2, dim, name)
+ return conv2
+
+
+def gather_feat(feat, ind, batch_size=1):
+ feats = []
+ for bind in range(batch_size):
+ feat_b = feat[bind]
+ ind_b = ind[bind]
+ ind_b.stop_gradient = True
+ feat_bg = fluid.layers.gather(feat_b, ind_b)
+ feats.append(fluid.layers.unsqueeze(feat_bg, axes=[0]))
+ feat_g = fluid.layers.concat(feats, axis=0)
+ return feat_g
+
+
+def mask_feat(feat, ind, batch_size=1):
+ feat_t = fluid.layers.transpose(feat, [0, 2, 3, 1])
+ C = feat_t.shape[3]
+ feat_r = fluid.layers.reshape(feat_t, [0, -1, C])
+ return gather_feat(feat_r, ind, batch_size)
+
+
+def nms(heat):
+ hmax = fluid.layers.pool2d(heat, pool_size=3, pool_padding=1)
+ keep = fluid.layers.cast(heat == hmax, 'float32')
+ return heat * keep
+
+
+def _topk(scores, batch_size, height, width, K):
+ scores_r = fluid.layers.reshape(scores, [batch_size, -1])
+ topk_scores, topk_inds = fluid.layers.topk(scores_r, K)
+ topk_clses = topk_inds / (height * width)
+ topk_inds = topk_inds % (height * width)
+ topk_ys = fluid.layers.cast(topk_inds / width, 'float32')
+ topk_xs = fluid.layers.cast(topk_inds % width, 'float32')
+ return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
+
+
+def filter_scores(scores, index_list):
+ for ind in index_list:
+ tmp = scores * fluid.layers.cast((1 - ind), 'float32')
+ scores = tmp - fluid.layers.cast(ind, 'float32')
+ return scores
+
+
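+# Decode corner predictions into detections: apply sigmoid and a 3x3 max-pool
+# NMS to both heatmaps, take the top-K corners of each, form all K x K
+# top-left/bottom-right pairs, add the predicted offsets, filter pairs with
+# mismatched classes, large embedding distance or inverted geometry, and keep
+# the num_dets highest-scoring boxes.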
+def decode(tl_heat,
+ br_heat,
+ tl_tag,
+ br_tag,
+ tl_regr,
+ br_regr,
+ ae_threshold=1,
+ num_dets=1000,
+ K=100,
+ batch_size=1):
+ shape = fluid.layers.shape(tl_heat)
+ H, W = shape[2], shape[3]
+
+ tl_heat = fluid.layers.sigmoid(tl_heat)
+ br_heat = fluid.layers.sigmoid(br_heat)
+
+ tl_heat_nms = nms(tl_heat)
+ br_heat_nms = nms(br_heat)
+
+ tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat_nms, batch_size,
+ H, W, K)
+ br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat_nms, batch_size,
+ H, W, K)
+ tl_ys = fluid.layers.expand(
+ fluid.layers.reshape(tl_ys, [-1, K, 1]), [1, 1, K])
+ tl_xs = fluid.layers.expand(
+ fluid.layers.reshape(tl_xs, [-1, K, 1]), [1, 1, K])
+ br_ys = fluid.layers.expand(
+ fluid.layers.reshape(br_ys, [-1, 1, K]), [1, K, 1])
+ br_xs = fluid.layers.expand(
+ fluid.layers.reshape(br_xs, [-1, 1, K]), [1, K, 1])
+
+ tl_regr = mask_feat(tl_regr, tl_inds, batch_size)
+ br_regr = mask_feat(br_regr, br_inds, batch_size)
+ tl_regr = fluid.layers.reshape(tl_regr, [-1, K, 1, 2])
+ br_regr = fluid.layers.reshape(br_regr, [-1, 1, K, 2])
+
+ tl_xs = tl_xs + tl_regr[:, :, :, 0]
+ tl_ys = tl_ys + tl_regr[:, :, :, 1]
+ br_xs = br_xs + br_regr[:, :, :, 0]
+ br_ys = br_ys + br_regr[:, :, :, 1]
+
+ bboxes = fluid.layers.stack([tl_xs, tl_ys, br_xs, br_ys], axis=-1)
+
+ tl_tag = mask_feat(tl_tag, tl_inds, batch_size)
+ br_tag = mask_feat(br_tag, br_inds, batch_size)
+ tl_tag = fluid.layers.expand(
+ fluid.layers.reshape(tl_tag, [-1, K, 1]), [1, 1, K])
+ br_tag = fluid.layers.expand(
+ fluid.layers.reshape(br_tag, [-1, 1, K]), [1, K, 1])
+ dists = fluid.layers.abs(tl_tag - br_tag)
+
+ tl_scores = fluid.layers.expand(
+ fluid.layers.reshape(tl_scores, [-1, K, 1]), [1, 1, K])
+ br_scores = fluid.layers.expand(
+ fluid.layers.reshape(br_scores, [-1, 1, K]), [1, K, 1])
+ scores = (tl_scores + br_scores) / 2.
+
+ tl_clses = fluid.layers.expand(
+ fluid.layers.reshape(tl_clses, [-1, K, 1]), [1, 1, K])
+ br_clses = fluid.layers.expand(
+ fluid.layers.reshape(br_clses, [-1, 1, K]), [1, K, 1])
+ cls_inds = fluid.layers.cast(tl_clses != br_clses, 'int32')
+ dist_inds = fluid.layers.cast(dists > ae_threshold, 'int32')
+
+ width_inds = fluid.layers.cast(br_xs < tl_xs, 'int32')
+ height_inds = fluid.layers.cast(br_ys < tl_ys, 'int32')
+
+ scores = filter_scores(scores,
+ [cls_inds, dist_inds, width_inds, height_inds])
+ scores = fluid.layers.reshape(scores, [-1, K * K])
+
+ scores, inds = fluid.layers.topk(scores, num_dets)
+ scores = fluid.layers.reshape(scores, [-1, num_dets, 1])
+
+ bboxes = fluid.layers.reshape(bboxes, [batch_size, -1, 4])
+ bboxes = gather_feat(bboxes, inds, batch_size)
+
+ clses = fluid.layers.reshape(tl_clses, [batch_size, -1, 1])
+ clses = gather_feat(clses, inds, batch_size)
+
+ tl_scores = fluid.layers.reshape(tl_scores, [batch_size, -1, 1])
+ tl_scores = gather_feat(tl_scores, inds, batch_size)
+ br_scores = fluid.layers.reshape(br_scores, [batch_size, -1, 1])
+ br_scores = gather_feat(br_scores, inds, batch_size)
+
+ bboxes = fluid.layers.cast(bboxes, 'float32')
+ clses = fluid.layers.cast(clses, 'float32')
+ return bboxes, scores, tl_scores, br_scores, clses
+
+
+@register
+class CornerHead(object):
+ """
+ CornerNet head with corner_pooling
+
+ Args:
+ train_batch_size(int): batch_size in training process
+ test_batch_size(int): batch_size in test process, 1 by default
+ num_classes(int): num of classes, 80 by default
+ stack(int): stack of backbone, 2 by default
+ pull_weight(float): weight of pull_loss, 0.1 by default
+ push_weight(float): weight of push_loss, 0.1 by default
+ ae_threshold(float|int): threshold for valid distance of predicted tags, 1 by default
+ num_dets(int): num of detections, 1000 by default
+ top_k(int): choose top_k pair of corners in prediction, 100 by default
+ """
+ __shared__ = ['num_classes', 'stack']
+
+ def __init__(self,
+ train_batch_size,
+ test_batch_size=1,
+ num_classes=80,
+ stack=2,
+ pull_weight=0.1,
+ push_weight=0.1,
+ ae_threshold=1,
+ num_dets=1000,
+ top_k=100):
+ self.train_batch_size = train_batch_size
+ self.test_batch_size = test_batch_size
+ self.num_classes = num_classes
+ self.stack = stack
+ self.pull_weight = pull_weight
+ self.push_weight = push_weight
+ self.ae_threshold = ae_threshold
+ self.num_dets = num_dets
+ self.K = top_k
+ self.tl_heats = []
+ self.br_heats = []
+ self.tl_tags = []
+ self.br_tags = []
+ self.tl_offs = []
+ self.br_offs = []
+
+ def pred_mod(self, x, dim, name=None):
+ conv0 = _conv_norm(
+ x, 1, 256, with_bn=False, bn_act='relu', name=name + '_0')
+ conv1 = fluid.layers.conv2d(
+ input=conv0,
+ filter_size=1,
+ num_filters=dim,
+ param_attr=ParamAttr(
+ name=name + "_1_weight", initializer=kaiming_init(conv0, 1)),
+ bias_attr=ParamAttr(
+ name=name + "_1_bias", initializer=Constant(-2.19)),
+ name=name + '_1')
+ return conv1
+
+ def get_output(self, input):
+ for ind in range(self.stack):
+ cnv = input[ind]
+ tl_modules = corner_pool(
+ cnv,
+ 256,
+ cornerpool_lib.top_pool,
+ cornerpool_lib.left_pool,
+ name='tl_modules_' + str(ind))
+ br_modules = corner_pool(
+ cnv,
+ 256,
+ cornerpool_lib.bottom_pool,
+ cornerpool_lib.right_pool,
+ name='br_modules_' + str(ind))
+
+ tl_heat = self.pred_mod(
+ tl_modules, self.num_classes, name='tl_heats_' + str(ind))
+ br_heat = self.pred_mod(
+ br_modules, self.num_classes, name='br_heats_' + str(ind))
+
+ tl_tag = self.pred_mod(tl_modules, 1, name='tl_tags_' + str(ind))
+ br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind))
+
+ tl_off = self.pred_mod(tl_modules, 2, name='tl_offs_' + str(ind))
+ br_off = self.pred_mod(br_modules, 2, name='br_offs_' + str(ind))
+
+ self.tl_heats.append(tl_heat)
+ self.br_heats.append(br_heat)
+ self.tl_tags.append(tl_tag)
+ self.br_tags.append(br_tag)
+ self.tl_offs.append(tl_off)
+ self.br_offs.append(br_off)
+
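+ # Focal loss variant used by CornerNet: positions with gt == 1 are
+ # positives, all other positions are negatives down-weighted by
+ # (1 - gt)^4, and the loss is normalized by the number of positives.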
+ def focal_loss(self, preds, gt, gt_masks):
+ preds_clip = []
+ none_pos = fluid.layers.cast(
+ fluid.layers.reduce_sum(gt_masks) == 0, 'float32')
+ none_pos.stop_gradient = True
+ min = fluid.layers.assign(np.array([1e-4], dtype='float32'))
+ max = fluid.layers.assign(np.array([1 - 1e-4], dtype='float32'))
+ for pred in preds:
+ pred_s = fluid.layers.sigmoid(pred)
+ pred_min = fluid.layers.elementwise_max(pred_s, min)
+ pred_max = fluid.layers.elementwise_min(pred_min, max)
+ preds_clip.append(pred_max)
+
+ ones = fluid.layers.ones_like(gt)
+
+ fg_map = fluid.layers.cast(gt == ones, 'float32')
+ fg_map.stop_gradient = True
+ num_pos = fluid.layers.reduce_sum(fg_map)
+ min_num = fluid.layers.ones_like(num_pos)
+ num_pos = fluid.layers.elementwise_max(num_pos, min_num)
+ num_pos.stop_gradient = True
+ bg_map = fluid.layers.cast(gt < ones, 'float32')
+ bg_map.stop_gradient = True
+ neg_weights = fluid.layers.pow(1 - gt, 4) * bg_map
+ neg_weights.stop_gradient = True
+ loss = fluid.layers.assign(np.array([0], dtype='float32'))
+ for ind, pred in enumerate(preds_clip):
+ pos_loss = fluid.layers.log(pred) * fluid.layers.pow(1 - pred,
+ 2) * fg_map
+
+ neg_loss = fluid.layers.log(1 - pred) * fluid.layers.pow(
+ pred, 2) * neg_weights
+
+ pos_loss = fluid.layers.reduce_sum(pos_loss)
+ neg_loss = fluid.layers.reduce_sum(neg_loss)
+ focal_loss_ = (neg_loss + pos_loss) / (num_pos + none_pos)
+ loss -= focal_loss_
+ return loss
+
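+ # Associative embedding loss: the pull term draws the two corner
+ # embeddings of each object towards their mean, the push term penalizes
+ # pairs of different objects whose embedding means are closer than 1.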
+ def ae_loss(self, tl_tag, br_tag, gt_masks):
+ num = fluid.layers.reduce_sum(gt_masks, dim=1)
+ num.stop_gradient = True
+ tag0 = fluid.layers.squeeze(tl_tag, [2])
+ tag1 = fluid.layers.squeeze(br_tag, [2])
+ tag_mean = (tag0 + tag1) / 2
+
+ tag0 = fluid.layers.pow(tag0 - tag_mean, 2) / (num + 1e-4) * gt_masks
+ tag1 = fluid.layers.pow(tag1 - tag_mean, 2) / (num + 1e-4) * gt_masks
+ tag0 = fluid.layers.reduce_sum(tag0)
+ tag1 = fluid.layers.reduce_sum(tag1)
+
+ pull = tag0 + tag1
+
+ mask_1 = fluid.layers.expand(
+ fluid.layers.unsqueeze(gt_masks, [1]), [1, gt_masks.shape[1], 1])
+ mask_2 = fluid.layers.expand(
+ fluid.layers.unsqueeze(gt_masks, [2]), [1, 1, gt_masks.shape[1]])
+ mask = fluid.layers.cast((mask_1 + mask_2) == 2, 'float32')
+ mask.stop_gradient = True
+
+ num2 = (num - 1) * num
+ num2.stop_gradient = True
+ tag_mean_1 = fluid.layers.expand(
+ fluid.layers.unsqueeze(tag_mean, [1]), [1, tag_mean.shape[1], 1])
+ tag_mean_2 = fluid.layers.expand(
+ fluid.layers.unsqueeze(tag_mean, [2]), [1, 1, tag_mean.shape[1]])
+ dist = tag_mean_1 - tag_mean_2
+ dist = 1 - fluid.layers.abs(dist)
+ dist = fluid.layers.relu(dist)
+ dist = dist - 1 / (num + 1e-4)
+ dist = dist / (num2 + 1e-4)
+ dist = dist * mask
+ push = fluid.layers.reduce_sum(dist)
+ return pull, push
+
+ def off_loss(self, off, gt_off, gt_masks):
+ mask = fluid.layers.unsqueeze(gt_masks, [2])
+ mask = fluid.layers.expand_as(mask, gt_off)
+ mask.stop_gradient = True
+ off_loss = fluid.layers.smooth_l1(off, gt_off, mask, mask)
+ off_loss = fluid.layers.reduce_sum(off_loss)
+ total_num = fluid.layers.reduce_sum(gt_masks)
+ total_num.stop_gradient = True
+ return off_loss / (total_num + 1e-4)
+
+ def get_loss(self, targets):
+ gt_tl_heat = targets['tl_heatmaps']
+ gt_br_heat = targets['br_heatmaps']
+ gt_masks = targets['tag_masks']
+ gt_tl_off = targets['tl_regrs']
+ gt_br_off = targets['br_regrs']
+ gt_tl_ind = targets['tl_tags']
+ gt_br_ind = targets['br_tags']
+ gt_masks = fluid.layers.cast(gt_masks, 'float32')
+
+ focal_loss = 0
+ focal_loss_ = self.focal_loss(self.tl_heats, gt_tl_heat, gt_masks)
+ focal_loss += focal_loss_
+ focal_loss_ = self.focal_loss(self.br_heats, gt_br_heat, gt_masks)
+ focal_loss += focal_loss_
+
+ pull_loss = 0
+ push_loss = 0
+
+ ones = fluid.layers.assign(np.array([1], dtype='float32'))
+ tl_tags = [
+ mask_feat(tl_tag, gt_tl_ind, self.train_batch_size)
+ for tl_tag in self.tl_tags
+ ]
+ br_tags = [
+ mask_feat(br_tag, gt_br_ind, self.train_batch_size)
+ for br_tag in self.br_tags
+ ]
+
+ pull_loss, push_loss = 0, 0
+
+ for tl_tag, br_tag in zip(tl_tags, br_tags):
+ pull, push = self.ae_loss(tl_tag, br_tag, gt_masks)
+ pull_loss += pull
+ push_loss += push
+
+ tl_offs = [
+ mask_feat(tl_off, gt_tl_ind, self.train_batch_size)
+ for tl_off in self.tl_offs
+ ]
+ br_offs = [
+ mask_feat(br_off, gt_br_ind, self.train_batch_size)
+ for br_off in self.br_offs
+ ]
+
+ off_loss = 0
+ for tl_off, br_off in zip(tl_offs, br_offs):
+ off_loss += self.off_loss(tl_off, gt_tl_off, gt_masks)
+ off_loss += self.off_loss(br_off, gt_br_off, gt_masks)
+
+ pull_loss = self.pull_weight * pull_loss
+ push_loss = self.push_weight * push_loss
+
+ loss = (
+ focal_loss + pull_loss + push_loss + off_loss) / len(self.tl_heats)
+ return {'loss': loss}
+
+ def get_prediction(self, input):
+ ind = self.stack - 1
+ tl_modules = corner_pool(
+ input,
+ 256,
+ cornerpool_lib.top_pool,
+ cornerpool_lib.left_pool,
+ is_test=True,
+ name='tl_modules_' + str(ind))
+ br_modules = corner_pool(
+ input,
+ 256,
+ cornerpool_lib.bottom_pool,
+ cornerpool_lib.right_pool,
+ is_test=True,
+ name='br_modules_' + str(ind))
+
+ tl_heat = self.pred_mod(
+ tl_modules, self.num_classes, name='tl_heats_' + str(ind))
+ br_heat = self.pred_mod(
+ br_modules, self.num_classes, name='br_heats_' + str(ind))
+
+ tl_tag = self.pred_mod(tl_modules, 1, name='tl_tags_' + str(ind))
+ br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind))
+
+ tl_off = self.pred_mod(tl_modules, 2, name='tl_offs_' + str(ind))
+ br_off = self.pred_mod(br_modules, 2, name='br_offs_' + str(ind))
+
+ return decode(tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off,
+ self.ae_threshold, self.num_dets, self.K,
+ self.test_batch_size)
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index 652a383123b8ec51a5ea8b7bc8c0c2b6085da36d..566d9a612c8e71912dad627705a860a7405c8e87 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -25,6 +25,7 @@ from . import retinanet
from . import blazeface
from . import faceboxes
from . import fcos
+from . import cornernet_squeeze
from .faster_rcnn import *
from .mask_rcnn import *
@@ -37,3 +38,4 @@ from .retinanet import *
from .blazeface import *
from .faceboxes import *
from .fcos import *
+from .cornernet_squeeze import *
diff --git a/ppdet/modeling/architectures/cascade_mask_rcnn.py b/ppdet/modeling/architectures/cascade_mask_rcnn.py
index 30180ddacfe1742a82fa147f9c9e727a36135e79..e0bba61a4c5c60e67db178f01fbdb21ce19a9aae 100644
--- a/ppdet/modeling/architectures/cascade_mask_rcnn.py
+++ b/ppdet/modeling/architectures/cascade_mask_rcnn.py
@@ -408,7 +408,7 @@ class CascadeMaskRCNN(object):
box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox']
for key in box_fields:
inputs_def[key] = {
- 'shape': [6],
+ 'shape': [None, 6],
'dtype': 'float32',
'lod_level': 1
}
diff --git a/ppdet/modeling/architectures/cornernet_squeeze.py b/ppdet/modeling/architectures/cornernet_squeeze.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcba613783195dfb3ec50711c5715d979a693c19
--- /dev/null
+++ b/ppdet/modeling/architectures/cornernet_squeeze.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+
+from paddle import fluid
+
+from ppdet.core.workspace import register
+import numpy as np
+
+__all__ = ['CornerNetSqueeze']
+
+
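+# Map predicted boxes from the network input scale back to the original image
+# by dividing by the resize ratios and subtracting the padded borders.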
+def rescale_bboxes(bboxes, ratios, borders):
+ x1, y1, x2, y2 = fluid.layers.split(bboxes, 4)
+ x1 = x1 / ratios[:, 1] - borders[:, 2]
+ x2 = x2 / ratios[:, 1] - borders[:, 2]
+ y1 = y1 / ratios[:, 0] - borders[:, 0]
+ y2 = y2 / ratios[:, 0] - borders[:, 0]
+ return fluid.layers.concat([x1, y1, x2, y2], axis=2)
+
+
+@register
+class CornerNetSqueeze(object):
+ """
+ """
+ __category__ = 'architecture'
+ __inject__ = ['backbone', 'corner_head', 'fpn']
+ __shared__ = ['num_classes']
+
+ def __init__(self,
+ backbone,
+ corner_head='CornerHead',
+ num_classes=80,
+ fpn=None):
+ super(CornerNetSqueeze, self).__init__()
+ self.backbone = backbone
+ self.corner_head = corner_head
+ self.num_classes = num_classes
+ self.fpn = fpn
+
+ def build(self, feed_vars, mode='train'):
+ im = feed_vars['image']
+ body_feats = self.backbone(im)
+ if self.fpn is not None:
+ body_feats, _ = self.fpn.get_output(body_feats)
+ body_feats = [list(body_feats.values())[-1]]
+ if mode == 'train':
+ target_vars = [
+ 'tl_heatmaps', 'br_heatmaps', 'tag_masks', 'tl_regrs',
+ 'br_regrs', 'tl_tags', 'br_tags'
+ ]
+ target = {key: feed_vars[key] for key in target_vars}
+ self.corner_head.get_output(body_feats)
+ loss = self.corner_head.get_loss(target)
+ return loss
+
+ elif mode == 'test':
+ ratios = feed_vars['ratios']
+ borders = feed_vars['borders']
+ bboxes, scores, tl_scores, br_scores, clses = self.corner_head.get_prediction(
+ body_feats[-1])
+ bboxes = rescale_bboxes(bboxes, ratios, borders)
+ detections = fluid.layers.concat([clses, scores, bboxes], axis=2)
+
+ detections = detections[0]
+ return {'bbox': detections}
+
+ def _inputs_def(self, image_shape, output_size, max_tag_len):
+ im_shape = [None] + image_shape
+ C = self.num_classes
+ # yapf: disable
+ inputs_def = {
+ 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
+ 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
+ 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
+ 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
+ 'ratios': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 0},
+ 'borders': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0},
+ 'tl_heatmaps': {'shape': [None, C, output_size, output_size], 'dtype': 'float32', 'lod_level': 0},
+ 'br_heatmaps': {'shape': [None, C, output_size, output_size], 'dtype': 'float32', 'lod_level': 0},
+ 'tl_regrs': {'shape': [None, max_tag_len, 2], 'dtype': 'float32', 'lod_level': 0},
+ 'br_regrs': {'shape': [None, max_tag_len, 2], 'dtype': 'float32', 'lod_level': 0},
+ 'tl_tags': {'shape': [None, max_tag_len], 'dtype': 'int64', 'lod_level': 0},
+ 'br_tags': {'shape': [None, max_tag_len], 'dtype': 'int64', 'lod_level': 0},
+ 'tag_masks': {'shape': [None, max_tag_len], 'dtype': 'int32', 'lod_level': 0},
+ }
+ # yapf: enable
+ return inputs_def
+
+ def build_inputs(
+ self,
+ image_shape=[3, None, None],
+ fields=[
+ 'image', 'im_id', 'gt_bbox', 'gt_class', 'tl_heatmaps',
+ 'br_heatmaps', 'tl_regrs', 'br_regrs', 'tl_tags', 'br_tags',
+ 'tag_masks'
+ ], # for train
+ output_size=64,
+ max_tag_len=128,
+ use_dataloader=True,
+ iterable=False):
+ inputs_def = self._inputs_def(image_shape, output_size, max_tag_len)
+ feed_vars = OrderedDict([(key, fluid.data(
+ name=key,
+ shape=inputs_def[key]['shape'],
+ dtype=inputs_def[key]['dtype'],
+ lod_level=inputs_def[key]['lod_level'])) for key in fields])
+ loader = fluid.io.DataLoader.from_generator(
+ feed_list=list(feed_vars.values()),
+ capacity=64,
+ use_double_buffer=True,
+ iterable=iterable) if use_dataloader else None
+ return feed_vars, loader
+
+ def train(self, feed_vars):
+ return self.build(feed_vars, mode='train')
+
+ def eval(self, feed_vars):
+ return self.build(feed_vars, mode='test')
+
+ def test(self, feed_vars):
+ return self.build(feed_vars, mode='test')
diff --git a/ppdet/modeling/architectures/mask_rcnn.py b/ppdet/modeling/architectures/mask_rcnn.py
index 1f3f0104a14e023ba44c50e514eabe4f4c39b92c..f7eb28fdad4cd7c3b5c8f62ddf72c154fff9f862 100644
--- a/ppdet/modeling/architectures/mask_rcnn.py
+++ b/ppdet/modeling/architectures/mask_rcnn.py
@@ -311,7 +311,7 @@ class MaskRCNN(object):
box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox']
for key in box_fields:
inputs_def[key] = {
- 'shape': [6],
+ 'shape': [None, 6],
'dtype': 'float32',
'lod_level': 1
}
diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py
index 2c31e792dc88c505c0e4a3e615be2ec390357c58..35fc91dcc89f4092d8bc113c574f0c9f44768062 100644
--- a/ppdet/modeling/backbones/__init__.py
+++ b/ppdet/modeling/backbones/__init__.py
@@ -29,6 +29,7 @@ from . import res2net
from . import hrnet
from . import hrfpn
from . import bfp
+from . import hourglass
from .resnet import *
from .resnext import *
@@ -45,3 +46,4 @@ from .res2net import *
from .hrnet import *
from .hrfpn import *
from .bfp import *
+from .hourglass import *
diff --git a/ppdet/modeling/backbones/hourglass.py b/ppdet/modeling/backbones/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..b38f79bb408af2c1e0fb47e5faef908130e04e8c
--- /dev/null
+++ b/ppdet/modeling/backbones/hourglass.py
@@ -0,0 +1,275 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import Uniform
+
+import functools
+from ppdet.core.workspace import register
+from .resnet import ResNet
+import math
+
+__all__ = ['Hourglass']
+
+
+def kaiming_init(input, filter_size):
+ fan_in = input.shape[1]
+ std = (1.0 / (fan_in * filter_size * filter_size))**0.5
+ return Uniform(0. - std, std)
+
+
+def _conv_norm(x,
+ k,
+ out_dim,
+ stride=1,
+ pad=0,
+ groups=None,
+ with_bn=True,
+ bn_act=None,
+ ind=None,
+ name=None):
+ conv_name = "_conv" if ind is None else "_conv" + str(ind)
+ bn_name = "_bn" if ind is None else "_bn" + str(ind)
+
+ conv = fluid.layers.conv2d(
+ input=x,
+ filter_size=k,
+ num_filters=out_dim,
+ stride=stride,
+ padding=pad,
+ groups=groups,
+ param_attr=ParamAttr(
+ name=name + conv_name + "_weight", initializer=kaiming_init(x, k)),
+ bias_attr=ParamAttr(
+ name=name + conv_name + "_bias", initializer=kaiming_init(x, k))
+ if not with_bn else False,
+ name=name + '_output')
+ if with_bn:
+ pattr = ParamAttr(name=name + bn_name + '_weight')
+ battr = ParamAttr(name=name + bn_name + '_bias')
+ out = fluid.layers.batch_norm(
+ input=conv,
+ act=bn_act,
+ name=name + '_bn_output',
+ param_attr=pattr,
+ bias_attr=battr,
+ moving_mean_name=name + bn_name + '_running_mean',
+ moving_variance_name=name + bn_name +
+ '_running_var') if with_bn else conv
+ else:
+ out = fluid.layers.relu(conv)
+ return out
+
+
+def residual_block(x, out_dim, k=3, stride=1, name=None):
+ p = (k - 1) // 2
+ conv1 = _conv_norm(
+ x, k, out_dim, pad=p, stride=stride, bn_act='relu', ind=1, name=name)
+ conv2 = _conv_norm(conv1, k, out_dim, pad=p, ind=2, name=name)
+
+ skip = _conv_norm(
+ x, 1, out_dim, stride=stride,
+ name=name + '_skip') if stride != 1 or x.shape[1] != out_dim else x
+ return fluid.layers.elementwise_add(
+ x=skip, y=conv2, act='relu', name=name + "_add")
+
+
+def fire_block(x, out_dim, sr=2, stride=1, name=None):
+ conv1 = _conv_norm(x, 1, out_dim // sr, ind=1, name=name)
+ conv_1x1 = fluid.layers.conv2d(
+ conv1,
+ filter_size=1,
+ num_filters=out_dim // 2,
+ stride=stride,
+ param_attr=ParamAttr(
+ name=name + "_conv_1x1_weight", initializer=kaiming_init(conv1, 1)),
+ bias_attr=False,
+ name=name + '_conv_1x1')
+ conv_3x3 = fluid.layers.conv2d(
+ conv1,
+ filter_size=3,
+ num_filters=out_dim // 2,
+ stride=stride,
+ padding=1,
+ groups=out_dim // sr,
+ param_attr=ParamAttr(
+ name=name + "_conv_3x3_weight", initializer=kaiming_init(conv1, 3)),
+ bias_attr=False,
+ name=name + '_conv_3x3',
+ use_cudnn=False)
+ conv2 = fluid.layers.concat(
+ [conv_1x1, conv_3x3], axis=1, name=name + '_conv2')
+ pattr = ParamAttr(name=name + '_bn2_weight')
+ battr = ParamAttr(name=name + '_bn2_bias')
+
+ bn2 = fluid.layers.batch_norm(
+ input=conv2,
+ name=name + '_bn2',
+ param_attr=pattr,
+ bias_attr=battr,
+ moving_mean_name=name + '_bn2_running_mean',
+ moving_variance_name=name + '_bn2_running_var')
+
+ if stride == 1 and x.shape[1] == out_dim:
+ return fluid.layers.elementwise_add(
+ x=bn2, y=x, act='relu', name=name + "_add_relu")
+ else:
+ return fluid.layers.relu(bn2, name=name + "_relu")
+
+
+def make_layer(x, in_dim, out_dim, modules, block, name=None):
+ layers = block(x, out_dim, name=name + '_0')
+ for i in range(1, modules):
+ layers = block(layers, out_dim, name=name + '_' + str(i))
+ return layers
+
+
+def make_hg_layer(x, in_dim, out_dim, modules, block, name=None):
+ layers = block(x, out_dim, stride=2, name=name + '_0')
+ for i in range(1, modules):
+ layers = block(layers, out_dim, name=name + '_' + str(i))
+ return layers
+
+
+def make_layer_revr(x, in_dim, out_dim, modules, block, name=None):
+ for i in range(modules - 1):
+ x = block(x, in_dim, name=name + '_' + str(i))
+ layers = block(x, out_dim, name=name + '_' + str(modules - 1))
+ return layers
+
+
+def make_unpool_layer(x, dim, name=None):
+ pattr = ParamAttr(name=name + '_weight', initializer=kaiming_init(x, 4))
+ battr = ParamAttr(name=name + '_bias', initializer=kaiming_init(x, 4))
+ layer = fluid.layers.conv2d_transpose(
+ input=x,
+ num_filters=dim,
+ filter_size=4,
+ stride=2,
+ padding=1,
+ param_attr=pattr,
+ bias_attr=battr)
+ return layer
+
+
+@register
+class Hourglass(object):
+ """
+ Hourglass Network, see https://arxiv.org/abs/1603.06937
+ Args:
+ stack (int): stack of hourglass, 2 by default
+ dims (list): dims of each level in hg_module
+ modules (list): num of modules in each level
+ """
+ __shared__ = ['stack']
+
+ def __init__(self,
+ stack=2,
+ dims=[256, 256, 384, 384, 512],
+ modules=[2, 2, 2, 2, 4],
+ block_name='fire'):
+ super(Hourglass, self).__init__()
+ self.stack = stack
+ assert len(dims) == len(modules), \
+ "Expected len of dims equal to len of modules, Receiced len of "\
+ "dims: {}, len of modules: {}".format(len(dims), len(modules))
+ self.dims = dims
+ self.modules = modules
+ self.num_level = len(dims) - 1
+ block_dict = {'fire': fire_block}
+ self.block = block_dict[block_name]
+
+ def __call__(self, input, name='hg'):
+ inter = self.pre(input, name + '_pre')
+ cnvs = []
+ for ind in range(self.stack):
+ hg = self.hg_module(
+ inter,
+ self.num_level,
+ self.dims,
+ self.modules,
+ name=name + '_hgs_' + str(ind))
+ cnv = _conv_norm(
+ hg,
+ 3,
+ 256,
+ bn_act='relu',
+ pad=1,
+ name=name + '_cnvs_' + str(ind))
+ cnvs.append(cnv)
+
+ if ind < self.stack - 1:
+ inter = _conv_norm(
+ inter, 1, 256, name=name + '_inters__' +
+ str(ind)) + _conv_norm(
+ cnv, 1, 256, name=name + '_cnvs__' + str(ind))
+ inter = fluid.layers.relu(inter)
+ inter = residual_block(
+ inter, 256, name=name + '_inters_' + str(ind))
+ return cnvs
+
+ def pre(self, x, name=None):
+ conv = _conv_norm(
+ x, 7, 128, stride=2, pad=3, bn_act='relu', name=name + '_0')
+ res1 = residual_block(conv, 256, stride=2, name=name + '_1')
+ res2 = residual_block(res1, 256, stride=2, name=name + '_2')
+ return res2
+
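+ # Build one hourglass level recursively: an up branch at the current
+ # resolution, a stride-2 down branch, a recursive (or bottom) lower
+ # module, a deconv upsampling layer, and an elementwise sum merging the
+ # two branches.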
+ def hg_module(self,
+ x,
+ n=4,
+ dims=[256, 256, 384, 384, 512],
+ modules=[2, 2, 2, 2, 4],
+ make_up_layer=make_layer,
+ make_hg_layer=make_hg_layer,
+ make_low_layer=make_layer,
+ make_hg_layer_revr=make_layer_revr,
+ make_unpool_layer=make_unpool_layer,
+ name=None):
+ curr_mod = modules[0]
+ next_mod = modules[1]
+ curr_dim = dims[0]
+ next_dim = dims[1]
+ up1 = make_up_layer(
+ x, curr_dim, curr_dim, curr_mod, self.block, name=name + '_up1')
+ max1 = x
+ low1 = make_hg_layer(
+ max1, curr_dim, next_dim, curr_mod, self.block, name=name + '_low1')
+ low2 = self.hg_module(
+ low1,
+ n - 1,
+ dims[1:],
+ modules[1:],
+ make_up_layer=make_up_layer,
+ make_hg_layer=make_hg_layer,
+ make_low_layer=make_low_layer,
+ make_hg_layer_revr=make_hg_layer_revr,
+ make_unpool_layer=make_unpool_layer,
+ name=name + '_low2') if n > 1 else make_low_layer(
+ low1,
+ next_dim,
+ next_dim,
+ next_mod,
+ self.block,
+ name=name + '_low2')
+ low3 = make_hg_layer_revr(
+ low2, next_dim, curr_dim, curr_mod, self.block, name=name + '_low3')
+ up2 = make_unpool_layer(low3, curr_dim, name=name + '_up2')
+ merg = fluid.layers.elementwise_add(x=up1, y=up2, name=name + '_merg')
+ return merg
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 28abd477c80ed19e8e5320501d84e34e2f806b14..ca861cf183e3704c4cd437019159f022d1519359 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -378,7 +378,7 @@ class MultiClassSoftNMS(object):
fluid.default_main_program(),
name='softnms_pred_result',
dtype='float32',
- shape=[6],
+ shape=[-1, 6],
lod_level=1)
fluid.layers.py_func(
func=_soft_nms, x=[bboxes, scores], out=pred_result)
diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py
index e7a6eb8ab0e8729e5187212671b7fd151627b7e9..1b91bfa343d9dc11bad2e64c0af29c69f6183ece 100644
--- a/ppdet/optimizer.py
+++ b/ppdet/optimizer.py
@@ -148,10 +148,12 @@ class OptimizerBuilder():
self.optimizer = optimizer
def __call__(self, learning_rate):
- reg_type = self.regularizer['type'] + 'Decay'
- reg_factor = self.regularizer['factor']
- regularization = getattr(regularizer, reg_type)(reg_factor)
-
+ if self.regularizer:
+ reg_type = self.regularizer['type'] + 'Decay'
+ reg_factor = self.regularizer['factor']
+ regularization = getattr(regularizer, reg_type)(reg_factor)
+ else:
+ regularization = None
optim_args = self.optimizer.copy()
optim_type = optim_args['type']
del optim_args['type']
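# Minimal sketch of what the guard above enables (argument values are assumed,
# not taken from this patch, and the constructor is assumed to keep its
# regularizer/optimizer arguments): an OptimizerBuilder configured without a
# regularizer now passes regularization=None to the underlying optimizer
# instead of failing on self.regularizer['type'].
from ppdet.optimizer import OptimizerBuilder

builder = OptimizerBuilder(
    regularizer=None,   # e.g. `regularizer: false` in a YAML config
    optimizer={'type': 'Adam', 'beta1': 0.9, 'beta2': 0.999})
# builder(learning_rate) then builds Adam with no weight-decay regularization.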
diff --git a/ppdet/utils/coco_eval.py b/ppdet/utils/coco_eval.py
index d2ceeacb42e6575d7e7dbbeeb2b4a59f86493634..8ccae76ead693ac3b0b6bcbc99694abf8cd4a514 100644
--- a/ppdet/utils/coco_eval.py
+++ b/ppdet/utils/coco_eval.py
@@ -230,9 +230,10 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
xywh_res = []
for t in results:
bboxes = t['bbox'][0]
+ if len(t['bbox'][1]) == 0: continue
lengths = t['bbox'][1][0]
im_ids = np.array(t['im_id'][0]).flatten()
- if bboxes.shape == (1, 1) or bboxes is None:
+ if bboxes.shape == (1, 1) or bboxes is None or len(bboxes) == 0:
continue
k = 0
diff --git a/ppdet/utils/eval_utils.py b/ppdet/utils/eval_utils.py
index 717053ad34b6eb3ae36a51327fe52597af67667a..ef1d537b3dc626406ca069ae34b06e2d3616e668 100644
--- a/ppdet/utils/eval_utils.py
+++ b/ppdet/utils/eval_utils.py
@@ -135,7 +135,8 @@ def eval_run(exe,
mask_multi_scale_test = multi_scale_test and 'Mask' in cfg.architecture
if multi_scale_test:
- post_res = mstest_box_post_process(res, cfg)
+ post_res = mstest_box_post_process(res, multi_scale_test,
+ cfg.num_classes)
res.update(post_res)
if mask_multi_scale_test:
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
@@ -156,10 +157,16 @@ def eval_run(exe,
if 'mask' in res:
from ppdet.utils.post_process import mask_encode
res['mask'] = mask_encode(res, resolution)
+ post_config = getattr(cfg, 'PostProcess', None)
+ if 'Corner' in cfg.architecture and post_config is not None:
+ from ppdet.utils.post_process import corner_post_process
+ corner_post_process(res, post_config, cfg.num_classes)
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
iter_id += 1
+ if len(res['bbox'][1]) == 0:
+ has_bbox = False
images_num += len(res['bbox'][1][0]) if has_bbox else 1
except (StopIteration, fluid.core.EOFException):
loader.reset()
diff --git a/ppdet/utils/post_process.py b/ppdet/utils/post_process.py
index c42fcc5a71798f4adc648415b647585514207d7e..cf251998348504332675799bd3d3ccfb342d515c 100644
--- a/ppdet/utils/post_process.py
+++ b/ppdet/utils/post_process.py
@@ -38,7 +38,7 @@ def box_flip(boxes, im_shape):
def nms(dets, thresh):
"""Apply classic DPM-style greedy NMS."""
if dets.shape[0] == 0:
- return []
+ return dets[[], :]
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
@@ -86,8 +86,40 @@ def nms(dets, thresh):
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
-
- return np.where(suppressed == 0)[0]
+ keep = np.where(suppressed == 0)[0]
+ dets = dets[keep, :]
+ return dets
+
+
+def soft_nms(dets, sigma, thres):
+ dets_final = []
+ while len(dets) > 0:
+ maxpos = np.argmax(dets[:, 0])
+ dets_final.append(dets[maxpos].copy())
+ ts, tx1, ty1, tx2, ty2 = dets[maxpos]
+ scores = dets[:, 0]
+ # force remove bbox at maxpos
+ scores[maxpos] = -1
+ x1 = dets[:, 1]
+ y1 = dets[:, 2]
+ x2 = dets[:, 3]
+ y2 = dets[:, 4]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ xx1 = np.maximum(tx1, x1)
+ yy1 = np.maximum(ty1, y1)
+ xx2 = np.minimum(tx2, x2)
+ yy2 = np.minimum(ty2, y2)
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas + areas[maxpos] - inter)
+ weight = np.exp(-(ovr * ovr) / sigma)
+ scores = scores * weight
+ idx_keep = np.where(scores >= thres)
+ dets[:, 0] = scores
+ dets = dets[idx_keep]
+ dets_final = np.array(dets_final).reshape(-1, 5)
+ return dets_final
def bbox_area(box):
@@ -128,39 +160,49 @@ def box_voting(nms_dets, dets, vote_thresh):
return top_dets
-def get_nms_result(boxes, scores, cfg):
- cls_boxes = [[] for _ in range(cfg.num_classes)]
- for j in range(1, cfg.num_classes):
- inds = np.where(scores[:, j] > cfg.MultiScaleTEST['score_thresh'])[0]
- scores_j = scores[inds, j]
- boxes_j = boxes[inds, j * 4:(j + 1) * 4]
+def get_nms_result(boxes,
+ scores,
+ config,
+ num_classes,
+ background_label=0,
+ labels=None):
+ has_labels = labels is not None
+ cls_boxes = [[] for _ in range(num_classes)]
+ start_idx = 1 if background_label == 0 else 0
+ for j in range(start_idx, num_classes):
+ inds = np.where(labels == j)[0] if has_labels else np.where(
+ scores[:, j] > config['score_thresh'])[0]
+ scores_j = scores[inds] if has_labels else scores[inds, j]
+ boxes_j = boxes[inds, :] if has_labels else boxes[inds, j * 4:(j + 1) *
+ 4]
dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype(
np.float32, copy=False)
- keep = nms(dets_j, cfg.MultiScaleTEST['nms_thresh'])
- nms_dets = dets_j[keep, :]
- if cfg.MultiScaleTEST['enable_voting']:
- nms_dets = box_voting(nms_dets, dets_j,
- cfg.MultiScaleTEST['vote_thresh'])
+ if config.get('use_soft_nms', False):
+ nms_dets = soft_nms(dets_j, config['sigma'], config['nms_thresh'])
+ else:
+ nms_dets = nms(dets_j, config['nms_thresh'])
+ if config.get('enable_voting', False):
+ nms_dets = box_voting(nms_dets, dets_j, config['vote_thresh'])
#add labels
- label = np.array([j for _ in range(len(keep))])
+ label = np.array([j for _ in range(len(nms_dets))])
nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype(
np.float32, copy=False)
cls_boxes[j] = nms_dets
# Limit to max_per_image detections **over all classes**
image_scores = np.hstack(
- [cls_boxes[j][:, 1] for j in range(1, cfg.num_classes)])
- if len(image_scores) > cfg.MultiScaleTEST['detections_per_im']:
- image_thresh = np.sort(image_scores)[-cfg.MultiScaleTEST[
- 'detections_per_im']]
- for j in range(1, cfg.num_classes):
+ [cls_boxes[j][:, 1] for j in range(start_idx, num_classes)])
+ if len(image_scores) > config['detections_per_im']:
+ image_thresh = np.sort(image_scores)[-config['detections_per_im']]
+ for j in range(start_idx, num_classes):
keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0]
cls_boxes[j] = cls_boxes[j][keep, :]
- im_results = np.vstack([cls_boxes[j] for j in range(1, cfg.num_classes)])
+ im_results = np.vstack(
+ [cls_boxes[j] for j in range(start_idx, num_classes)])
return im_results
-def mstest_box_post_process(result, cfg):
+def mstest_box_post_process(result, config, num_classes):
"""
Multi-scale Test
Only available for batch_size=1 now.
@@ -173,7 +215,7 @@ def mstest_box_post_process(result, cfg):
for k in result.keys():
if 'bbox' in k:
boxes = result[k][0]
- boxes = np.reshape(boxes, (-1, 4 * cfg.num_classes))
+ boxes = np.reshape(boxes, (-1, 4 * num_classes))
scores = result['score' + k[4:]][0]
if 'flip' in k:
boxes = box_flip(boxes, im_shape)
@@ -183,7 +225,7 @@ def mstest_box_post_process(result, cfg):
ms_boxes = np.concatenate(ms_boxes)
ms_scores = np.concatenate(ms_scores)
- bbox_pred = get_nms_result(ms_boxes, ms_scores, cfg)
+ bbox_pred = get_nms_result(ms_boxes, ms_scores, config, num_classes)
post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])})
if use_flip:
bbox = bbox_pred[:, 2:]
@@ -271,3 +313,15 @@ def mask_encode(results, resolution, thresh_binarize=0.5):
im_mask[:, :, np.newaxis], order='F'))[0]
segms.append(segm)
return segms
+
+
+def corner_post_process(results, config, num_classes):
+ detections = results['bbox'][0]
+ keep_inds = (detections[:, 1] > -1)
+ detections = detections[keep_inds]
+ labels = detections[:, 0]
+ scores = detections[:, 1]
+ boxes = detections[:, 2:6]
+ cls_boxes = get_nms_result(
+ boxes, scores, config, num_classes, background_label=-1, labels=labels)
+ results.update({'bbox': (cls_boxes, [[len(cls_boxes)]])})
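# Hypothetical end-to-end sketch (config values and boxes are made up, not
# part of this patch): corner_post_process takes raw CornerNet detections as
# rows of [class_id, score, x1, y1, x2, y2], drops padded entries, and
# rewrites results['bbox'] with the per-class (soft-)NMS output produced by
# get_nms_result above.
import numpy as np
from ppdet.utils.post_process import corner_post_process

post_config = {
    'use_soft_nms': True,
    'sigma': 0.5,
    'nms_thresh': 0.5,
    'detections_per_im': 100,
}
raw = np.array(
    [
        # cls  score    x1     y1     x2     y2
        [0.0,  0.90,  10.0,  10.0,  50.0,  60.0],
        [0.0,  0.85,  12.0,  11.0,  52.0,  58.0],  # near-duplicate box
        [1.0,  0.70,  80.0,  40.0, 120.0,  90.0],
    ],
    dtype=np.float32)
results = {'bbox': (raw, [[len(raw)]])}
corner_post_process(results, post_config, num_classes=80)
# results['bbox'][0] now holds the [class, score, x1, y1, x2, y2] rows that
# survived per-class soft-NMS, capped at detections_per_im.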