diff --git a/PaddleCV/rrpn/README.md b/PaddleCV/rrpn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9e6fc34e3c4f4cc0189b388ac9f73891afce100
--- /dev/null
+++ b/PaddleCV/rrpn/README.md
@@ -0,0 +1,168 @@
+# RRPN: Rotated Object Detection
+
+---
+## Table of Contents
+
+- [Installation](#installation)
+- [Introduction](#introduction)
+- [Data Preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Inference and Visualization](#inference-and-visualization)
+
+## Installation
+
+Running the sample code in this directory requires PaddlePaddle Fluid of the develop version or later. If the PaddlePaddle in your environment is older than this, please update it following the instructions in the [installation guide](http://www.paddlepaddle.org/).
+
+
+## Introduction
+RRPN is a two-stage object detector extended from Faster RCNN; it can be used for text detection and for detecting rotated objects in general. It generates candidate regions from an image, extracts their features, classifies each region, and refines the box locations.
+
+The [RRPN](https://arxiv.org/abs/1703.01086) network consists of four main components:
+
+1. Backbone convolution layers. As a CNN-based detector, RRPN first extracts feature maps from the image with a base convolutional network; the feature maps are shared by the subsequent RPN and fully connected layers. This example uses [ResNet-50](https://arxiv.org/abs/1512.03385) as the backbone.
+2. Region Proposal Network (RPN). The RPN generates candidate regions (proposals). It places a fixed set of oriented anchors defined by scales, aspect ratios, and angles, uses softmax to classify each rotated anchor as foreground or background, and applies box regression to refine the anchors into accurate proposals (see the anchor-count sketch after this list).
+3. Rotated RoI Align. This layer takes the feature maps and the oriented proposals, maps each oriented proposal onto the feature map, and pools it into a fixed-size region feature map that is fed into fully connected layers for classification.
+4. Detection head. It classifies each proposal from its region features and applies box regression once more to obtain the final box locations.
+
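+Each feature-map position produces `len(anchor_sizes) * len(aspect_ratios) * len(anchor_angle)` oriented anchors. A minimal sketch using the default values from `config.py`:
+
+```python
+from itertools import product
+
+# default anchor settings from config.py
+anchor_sizes = [128, 256, 512]
+aspect_ratio = [0.2, 0.5, 1.0]
+anchor_angle = [-30.0, 0.0, 30.0, 60.0, 90.0, 120.0]
+
+# one oriented anchor per (size, ratio, angle) combination
+anchors_per_position = list(product(anchor_sizes, aspect_ratio, anchor_angle))
+print(len(anchors_per_position))  # 3 * 3 * 6 = 54
+```
+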
+### Compiling the Custom OPs
+
+Compile the custom OPs as follows:
+
+    Enter the `models/ext_op/src` directory and run the build script:
+    ```
+    cd models/ext_op/src
+    sh make.sh ${cuda_path} ${cudnn_path} ${nccl_path}
+    ```
+    Here `${cuda_path}`, `${cudnn_path}`, and `${nccl_path}` are the installation paths of CUDA, cuDNN, and NCCL; they must be specified on the command line.
+    After a successful build, `rrpn_lib.so` is generated under `models/ext_op/src`.
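+
+    To sanity-check the build, you can load the library the same way `models/ext_op/rrpn_lib.py` does (run from `PaddleCV/rrpn/`; this is just a quick check, not part of the training workflow):
+
+    ```python
+    import paddle.fluid as fluid
+    # the same call used in models/ext_op/rrpn_lib.py
+    fluid.load_op_library('models/ext_op/src/rrpn_lib.so')
+    ```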
+
+## Data Preparation
+### Public Dataset
+Training uses the [ICDAR2015 dataset](https://rrc.cvc.uab.es/?ch=4&com=downloads). You must register on the official site before downloading it.
+
+The data directory is organized as follows:
+
+```
+dataset/icdar2015/
+├── ch4_training_images
+│ ├── img_143.jpg
+│ ├── img_144.jpg
+| ...
+├── ch4_training_localization_transcription_gt
+│ ├── gt_img_143.txt
+│ ├── gt_img_144.txt
+| ...
+├── ch4_test_images
+│ ├── img_111.jpg
+│ ├── img_112.jpg
+| ...
+├── ch4_test_localization_transcription_gt
+│ ├── gt_img_111.txt
+│ ├── gt_img_112.txt
+| ...
+```
+### Custom Data
+The original RRPN only supports two classes (text vs. background). To train a multi-class model on your own data, set dataset to icdar2017 in utility.py and set class_num to the required number of classes, where class 0 is the background.
+
+When training on custom data, the directory structure is the same as for ICDAR2015, and the annotation format is:
+```
+x1, y1, x2, y2, x3, y3, x4, y4, class_name
+x1, y1, x2, y2, x3, y3, x4, y4, class_name
+```
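+
+A minimal sketch of reading one annotation file in this format (the helper name is hypothetical; fields are assumed to be comma-separated as shown above):
+
+```python
+def parse_gt_file(path):
+    boxes, labels = [], []
+    with open(path, encoding='utf-8-sig') as f:  # ICDAR gt files may carry a BOM
+        for line in f:
+            fields = line.strip().split(',')
+            if len(fields) < 9:
+                continue
+            boxes.append([float(v) for v in fields[:8]])  # x1, y1, ..., x4, y4
+            labels.append(fields[8])                      # class_name
+    return boxes, labels
+```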
+
+## Training
+
+**Download the pretrained model:** This example provides a ResNet-50 pretrained model, which can be downloaded with the following command:
+
+    sh ./pretrained/download.sh
+
+
+The pretrained model is loaded by setting `pretrained_model`. The same setting is also used to load an already-trained model when fine-tuning.
+Please make sure the pretrained model is downloaded and loaded correctly before training, otherwise the loss may become NaN during training.
+
+
+- RRPN
+
+ ```
+ python train.py \
+ --model_save_dir=output/ \
+ --pretrained_model=${path_to_pretrain_model} \
+    --data_dir=${path_to_data}
+ ```
+
+
+
+    - Set `export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` to train on 8 GPUs.
+
+    - For the full list of options, run:
+
+        python train.py --help
+
+**Data reader:** The data reader is defined in reader.py. Every image has its short side scaled proportionally to `scales`; if the long side then exceeds `max_size`, the image is rescaled again so that the long side equals `max_size`. During training, images are additionally rotated by a random angle.
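+
+The scale factor is computed roughly as in `_resize()` in `data_utils.py`; a self-contained sketch (the helper name is illustrative):
+
+```python
+import numpy as np
+
+def compute_scale(h, w, target_size=800, max_size=1333):
+    # scale the short side to target_size ...
+    im_scale = float(target_size) / min(h, w)
+    # ... but cap the long side at max_size
+    if np.round(im_scale * max(h, w)) > max_size:
+        im_scale = float(max_size) / max(h, w)
+    return im_scale
+
+print(compute_scale(720, 1280))  # ~1.0414: the long side is capped at 1333
+```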
+
+**Model settings:**
+
+* Rotated RoI Align is used.
+* During training pre\_nms=12000 and post\_nms=2000; during testing pre\_nms=6000 and post\_nms=1000. The NMS threshold is 0.7.
+* When the RPN computes labels, fg\_fraction=0.25, fg\_thresh=0.5, bg\_thresh\_hi=0.5, bg\_thresh\_lo=0.0.
+* When the RPN samples anchors, rpn\_fg\_fraction=0.5, rpn\_positive\_overlap=0.7, rpn\_negative\_overlap=0.3.
+
+
+**Training strategy:**
+* The default configuration uses 8 GPUs with a batch size of 1 per GPU.
+* The momentum optimizer is used, with momentum=0.9.
+* The weight decay is 0.02. Over the first 500 iterations the learning rate increases linearly from 0.00333 to 0.01; it is then decayed with multipliers 0.1 and 0.01 at iterations 6250 and 12500, and training runs for at most 17500 iterations. The maximum number of iterations and the learning-rate schedule can be set via max_iter and lr_steps in config.py (see the sketch after this list).
+* The learning rate of convolution biases outside the backbone is twice the overall learning rate.
+* In the backbone, the affine_layers parameters and the res2-layer parameters are not updated.
+
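+The schedule can be adjusted through the global config object from `config.py`, e.g. (a sketch; the values shown are the defaults described above):
+
+```python
+from config import cfg
+
+# set before the training program is built
+cfg.max_iter = 17500
+cfg.lr_steps = [6250, 12500]
+cfg.lr_gamma = 0.1
+```
+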
+## Evaluation
+
+Evaluation measures the performance of a trained model. This example uses the [official ICDAR2015 evaluation](https://rrc.cvc.uab.es/?com=contestant).
+
+`eval.py` is the main evaluation entry point. Example usage:
+
+- RRPN
+
+ ```
+ python eval.py \
+ --dataset=icdar2015 \
+ --pretrained_model=${path_to_trained_model}
+ ```
+
+    - Set `--pretrained_model=${path_to_trained_model}` to the trained model (not the initialization model).
+    - Set `export CUDA_VISIBLE_DEVICES=0` to evaluate on a single GPU.
+
+
+The evaluation results are listed below:
+
+RRPN
+
+| Model | Batch Size | Iterations | F1 |
+| :--------------- | :------------: | :------------------: |------: |
+| [RRPN](https://paddleseg.bj.bcebos.com/deploy/temp/model_final.tar) |8 | 17500 | 0.8048 |
+
+
+
+
+
+
+## Inference and Visualization
+
+Inference detects the objects in an image together with their classes; `infer.py` is the main entry point. Example usage:
+
+```
+python infer.py \
+ --pretrained_model=${path_to_trained_model} \
+ --image_path=dataset/icdar2015 \
+ --draw_threshold=0.6
+```
+
+Note: set the model path `${path_to_trained_model}` and the image path correctly. The GPU is used by default; set `--use_gpu=False` to run on CPU. `draw_threshold` adjusts the score threshold that controls how many detection boxes are drawn.
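+
+Detections are filtered by their classification score before being drawn, as in `draw_bounding_box_on_image` in `eval_helper.py`. Each detection row is `[class_id, score, x1, y1, ..., x4, y4]`; a sketch with made-up values:
+
+```python
+import numpy as np
+
+nms_out = np.array([[1, 0.92, 10, 10, 50, 10, 50, 30, 10, 30],
+                    [1, 0.41,  5,  5, 20,  5, 20, 15,  5, 15]])
+keep = nms_out[nms_out[:, 1] >= 0.6]  # score is column 1
+print(len(keep))  # 1
+```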
+
+The visualized predictions are shown below:
+
+![](image/img_119.jpg)
+![](image/img_120.jpg)
+
+RRPN prediction visualization
+
diff --git a/PaddleCV/rrpn/__init__.py b/PaddleCV/rrpn/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PaddleCV/rrpn/checkpoint.py b/PaddleCV/rrpn/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7062199e1b3a0fb0bd5619b503559a335e54b2e9
--- /dev/null
+++ b/PaddleCV/rrpn/checkpoint.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import errno
+import os
+import shutil
+import time
+import numpy as np
+import re
+import paddle.fluid as fluid
+import logging
+logger = logging.getLogger(__name__)
+
+
+def load_params(exe, prog, path):
+ """
+ Load model from the given path.
+ Args:
+ exe (fluid.Executor): The fluid.Executor object.
+ prog (fluid.Program): load weight to which Program object.
+        path (string): URL string or local model path.
+ """
+
+ if not os.path.exists(path):
+        raise ValueError("Model pretrain path {} does not "
+                         "exist.".format(path))
+
+ logger.info('Loading parameters from {}...'.format(path))
+
+ def _if_exist(var):
+ param_exist = os.path.exists(os.path.join(path, var.name))
+ do_load = param_exist
+ if do_load:
+ logger.debug('load weight {}'.format(var.name))
+ return do_load
+
+ fluid.io.load_vars(exe, path, prog, predicate=_if_exist)
+
+
+def save(exe, prog, path):
+ """
+    Save model to the given path.
+ Args:
+ exe (fluid.Executor): The fluid.Executor object.
+ prog (fluid.Program): save weight from which Program object.
+ path (string): the path to save model.
+ """
+ if os.path.isdir(path):
+ shutil.rmtree(path)
+ logger.info('Save model to {}.'.format(path))
+ fluid.io.save_persistables(exe, path, prog)
+
+
+def load_and_fusebn(exe, prog, path):
+ """
+ Fuse params of batch norm to scale and bias.
+ Args:
+ exe (fluid.Executor): The fluid.Executor object.
+        prog (fluid.Program): load weight to which Program object.
+        path (string): the path to load the model from.
+    """
+    logger.info('Load model from {} and fuse batch norm params if present...'.
+                format(path))
+
+ if not os.path.exists(path):
+ raise ValueError("Model path {} does not exists.".format(path))
+
+ def _if_exist(var):
+ b = os.path.exists(os.path.join(path, var.name))
+
+ if b:
+ logger.debug('load weight {}'.format(var.name))
+ return b
+
+ all_vars = list(filter(_if_exist, prog.list_vars()))
+
+ # Since the program uses affine-channel, there is no running mean and var
+ # in the program, here append running mean and var.
+ # NOTE, the params of batch norm should be like:
+ # x_scale
+ # x_offset
+ # x_mean
+ # x_variance
+ # x is any prefix
+ mean_variances = set()
+ bn_vars = []
+
+ bn_in_path = True
+
+ inner_prog = fluid.Program()
+ inner_start_prog = fluid.Program()
+ inner_block = inner_prog.global_block()
+ with fluid.program_guard(inner_prog, inner_start_prog):
+ for block in prog.blocks:
+ ops = list(block.ops)
+ if not bn_in_path:
+ break
+ for op in ops:
+ if op.type == 'affine_channel':
+ # remove 'scale' as prefix
+ scale_name = op.input('Scale')[0] # _scale
+ bias_name = op.input('Bias')[0] # _offset
+ prefix = scale_name[:-5]
+ mean_name = prefix + 'mean'
+ variance_name = prefix + 'variance'
+
+ if not os.path.exists(os.path.join(path, mean_name)):
+ bn_in_path = False
+ break
+ if not os.path.exists(os.path.join(path, variance_name)):
+ bn_in_path = False
+ break
+
+ bias = block.var(bias_name)
+
+ mean_vb = inner_block.create_var(
+ name=mean_name,
+ type=bias.type,
+ shape=bias.shape,
+ dtype=bias.dtype,
+ persistable=True)
+ variance_vb = inner_block.create_var(
+ name=variance_name,
+ type=bias.type,
+ shape=bias.shape,
+ dtype=bias.dtype,
+ persistable=True)
+
+ mean_variances.add(mean_vb)
+ mean_variances.add(variance_vb)
+
+ bn_vars.append(
+ [scale_name, bias_name, mean_name, variance_name])
+
+ if not bn_in_path:
+ fluid.io.load_vars(exe, path, prog, vars=all_vars)
+        logger.warning(
+            "There are no parameters of batch norm in model {}. "
+            "Skip fusing batch norm. Parameters loaded.".format(path))
+ return
+
+ # load running mean and running variance on cpu place into global scope.
+ place = fluid.CPUPlace()
+ exe_cpu = fluid.Executor(place)
+ fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])
+
+ # load params on real place into global scope.
+ fluid.io.load_vars(exe, path, prog, vars=all_vars)
+
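+    # Fold the BN statistics into the affine-channel scale/bias:
+    #   y = scale * (x - mean) / sqrt(var + eps) + bias
+    #     = (scale / sqrt(var + eps)) * x + (bias - mean * scale / sqrt(var + eps))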
+ eps = 1e-5
+ for names in bn_vars:
+ scale_name, bias_name, mean_name, var_name = names
+
+ scale = fluid.global_scope().find_var(scale_name).get_tensor()
+ bias = fluid.global_scope().find_var(bias_name).get_tensor()
+ mean = fluid.global_scope().find_var(mean_name).get_tensor()
+ var = fluid.global_scope().find_var(var_name).get_tensor()
+
+ scale_arr = np.array(scale)
+ bias_arr = np.array(bias)
+ mean_arr = np.array(mean)
+ var_arr = np.array(var)
+
+ bn_std = np.sqrt(np.add(var_arr, eps))
+ new_scale = np.float32(np.divide(scale_arr, bn_std))
+ new_bias = bias_arr - mean_arr * new_scale
+
+ # fuse to scale and bias in affine_channel
+ scale.set(new_scale, exe.place)
+ bias.set(new_bias, exe.place)
diff --git a/PaddleCV/rrpn/config.py b/PaddleCV/rrpn/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..7cfe7cd5b3a19fe23bfd5a15392d509abf3c6da2
--- /dev/null
+++ b/PaddleCV/rrpn/config.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+from edict import AttrDict
+import six
+import numpy as np
+
+_C = AttrDict()
+cfg = _C
+
+#
+# Training options
+#
+_C.TRAIN = AttrDict()
+
+# scales an image's shortest side
+_C.TRAIN.scales = [800]
+
+# max size of longest side
+_C.TRAIN.max_size = 1333
+
+# images per GPU in minibatch
+_C.TRAIN.im_per_batch = 1
+
+# roi minibatch size per image
+_C.TRAIN.batch_size_per_im = 256
+
+# target fraction of foreground roi minibatch
+_C.TRAIN.fg_fractrion = 0.25
+
+# overlap threshold for a foreground roi
+_C.TRAIN.fg_thresh = 0.5
+
+# overlap threshold for a background roi
+_C.TRAIN.bg_thresh_hi = 0.5
+_C.TRAIN.bg_thresh_lo = 0.0
+
+# If False, only resize image and not pad, image shape is different between
+# GPUs in one mini-batch. If True, image shape is the same in one mini-batch.
+_C.TRAIN.padding_minibatch = False
+
+# Snapshot period
+_C.TRAIN.snapshot_iter = 1000
+
+# number of RPN proposals to keep before NMS
+_C.TRAIN.rpn_pre_nms_top_n = 12000
+
+# number of RPN proposals to keep after NMS
+_C.TRAIN.rpn_post_nms_top_n = 2000
+
+# NMS threshold used on RPN proposals
+_C.TRAIN.rpn_nms_thresh = 0.7
+
+# min size in RPN proposals
+_C.TRAIN.rpn_min_size = 0.0
+
+# eta for adaptive NMS in RPN
+_C.TRAIN.rpn_eta = 1.0
+
+# number of RPN examples per image
+_C.TRAIN.rpn_batch_size_per_im = 256
+
+# remove anchors out of the image
+_C.TRAIN.rpn_straddle_thresh = 0.
+
+# target fraction of foreground examples per RPN minibatch
+_C.TRAIN.rpn_fg_fraction = 0.5
+
+# min overlap between anchor and gt box to be a positive example
+_C.TRAIN.rpn_positive_overlap = 0.7
+
+# max overlap between anchor and gt box to be a negative example
+_C.TRAIN.rpn_negative_overlap = 0.3
+
+# stopgrad at a specified stage
+_C.TRAIN.freeze_at = 2
+
+# min area of ground truth box
+_C.TRAIN.gt_min_area = -1
+
+#
+# Inference options
+#
+_C.TEST = AttrDict()
+
+# scales an image's shortest side
+_C.TEST.scales = [800]
+
+# max size of longest side
+_C.TEST.max_size = 1333
+
+# eta for adaptive NMS in RPN
+_C.TEST.rpn_eta = 1.0
+
+# min score threshold to infer
+_C.TEST.score_thresh = 0.01
+
+# overlap threshold used for NMS
+_C.TEST.nms_thresh = 0.3
+
+# number of RPN proposals to keep before NMS
+_C.TEST.rpn_pre_nms_top_n = 6000
+
+# number of RPN proposals to keep after NMS
+_C.TEST.rpn_post_nms_top_n = 1000
+
+# min size in RPN proposals
+_C.TEST.rpn_min_size = 0.0
+
+# max number of detections
+_C.TEST.detections_per_im = 300
+
+# NMS threshold used on RPN proposals
+_C.TEST.rpn_nms_thresh = 0.7
+
+#
+# Model options
+#
+
+# Whether use mask rcnn head
+_C.MASK_ON = True
+
+# weight for bbox regression targets
+_C.bbox_reg_weights = [10.0, 10.0, 5.0, 5.0, 1.0]
+
+# RPN anchor sizes
+_C.anchor_sizes = [128, 256, 512]
+
+# RPN anchor ratio
+_C.aspect_ratio = [0.2, 0.5, 1.0]
+
+# RPN anchor angle
+_C.anchor_angle = [-30.0, 0.0, 30.0, 60.0, 90.0, 120.0]
+
+# variance of anchors
+_C.variances = [1., 1., 1., 1., 1.]
+
+# stride of feature map
+_C.rpn_stride = [16.0, 16.0]
+
+# pooled width and pooled height
+_C.roi_resolution = 14
+
+# spatial scale
+_C.spatial_scale = 1. / 16.
+
+# resolution to represent rotated roi align
+_C.resolution = 14
+
+#
+# SOLVER options
+#
+
+# base learning rate used to derive the final learning rate
+_C.learning_rate = 0.01
+
+# maximum number of iterations
+_C.max_iter = 140000
+
+# warm up to learning rate
+_C.warm_up_iter = 500
+_C.start_factor = 1. / 3
+
+# lr steps_with_decay
+_C.lr_steps = [6250, 12500]
+_C.lr_gamma = 0.1
+
+# L2 regularization hyperparameter
+_C.weight_decay = 0.0001
+
+# momentum with SGD
+_C.momentum = 0.9
+
+#
+# ENV options
+#
+
+# support both CPU and GPU
+_C.use_gpu = True
+
+# Whether use parallel
+_C.parallel = True
+
+# Class number
+_C.class_num = 81
+
+# support pyreader
+_C.use_pyreader = True
+_C.TRAIN.min_size = 800
+_C.TRAIN.max_size = 1333
+_C.TEST.min_size = 1000
+# pixel mean values
+_C.pixel_means = [0.485, 0.456, 0.406]
+_C.pixel_std = [0.229, 0.224, 0.225]
+# clip box to prevent overflowing
+_C.bbox_clip = np.log(1000. / 16.)
+
+
+def merge_cfg_from_args(args, mode):
+ """Merge config keys, values in args into the global config."""
+ if mode == 'train':
+ sub_d = _C.TRAIN
+ else:
+ sub_d = _C.TEST
+ for k, v in sorted(six.iteritems(vars(args))):
+ d = _C
+        try:
+            value = eval(v)
+        except Exception:
+            value = v
+ if k in sub_d:
+ sub_d[k] = value
+ else:
+ d[k] = value
diff --git a/PaddleCV/rrpn/data_utils.py b/PaddleCV/rrpn/data_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..339f7cb3b0d35a122ec96a1ee80085d84b168d6b
--- /dev/null
+++ b/PaddleCV/rrpn/data_utils.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on:
+# --------------------------------------------------------
+# Detectron
+# Copyright (c) 2017-present, Facebook, Inc.
+# Licensed under the Apache License, Version 2.0;
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import cv2
+import numpy as np
+from config import cfg
+import os
+from PIL import Image
+
+
+class DatasetPath(object):
+ def __init__(self, mode, dataset_name):
+ self.mode = mode
+ self.data_dir = dataset_name
+
+ def get_data_dir(self):
+ if self.mode == 'train':
+ return os.path.join(self.data_dir, 'ch4_training_images')
+ elif self.mode == 'val':
+ return os.path.join(self.data_dir, 'ch4_test_images')
+
+ def get_file_list(self):
+ if self.mode == 'train':
+ return os.path.join(self.data_dir,
+ 'ch4_training_localization_transcription_gt')
+ elif self.mode == 'val':
+ return os.path.join(self.data_dir,
+ 'ch4_test_localization_transcription_gt')
+
+
+def get_image_blob(roidb, mode):
+ """Builds an input blob from the images in the roidb at the specified
+ scales.
+ """
+ if mode == 'train' or mode == 'val':
+ with open(roidb['image'], 'rb') as f:
+ data = f.read()
+ data = np.frombuffer(data, dtype='uint8')
+ img = cv2.imdecode(data, 1)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ gt_boxes = roidb['boxes']
+ gt_label = roidb['gt_classes']
+ # resize
+ if mode == 'train':
+ img, im_scale = _resize(img, target_size=800, max_size=1333)
+ need_gt_boxes = gt_boxes.copy()
+ need_gt_boxes[:, :4] *= im_scale
+ img, need_gt_boxes, need_gt_label = _rotation(
+ img, need_gt_boxes, gt_label, prob=1.0, gt_margin=1.4)
+ else:
+ img, im_scale = _resize(img, target_size=1000, max_size=1778)
+ need_gt_boxes = gt_boxes
+ need_gt_label = gt_label
+ img = img.astype(np.float32, copy=False)
+ img = img / 255.0
+ mean = np.array(cfg.pixel_means)[np.newaxis, np.newaxis, :]
+ std = np.array(cfg.pixel_std)[np.newaxis, np.newaxis, :]
+ img -= mean
+ img /= std
+ img = img.transpose((2, 0, 1))
+ return img, im_scale, need_gt_boxes, need_gt_label
+
+
+def _get_size_scale(w, h, min_size, max_size=None):
+ size = min_size
+ scale = 1.0
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = int(round(max_size * min_original_size / max_original_size))
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (h, w), scale
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ scale = size / w
+ else:
+ oh = size
+ ow = int(size * w / h)
+ scale = size / h
+ scale = ow / w
+ return (oh, ow), scale
+
+
+def _resize(im, target_size=800, max_size=1333):
+    if not isinstance(im, np.ndarray):
+        raise TypeError("image type is not numpy.ndarray.")
+    if len(im.shape) != 3:
+        raise ValueError('image is not 3-dimensional.')
+ im_shape = im.shape
+ im_size_min = np.min(im_shape[0:2])
+ im_size_max = np.max(im_shape[0:2])
+ selected_size = target_size
+ if float(im_size_min) == 0:
+ raise ZeroDivisionError('min size of image is 0')
+ if max_size != 0:
+ im_scale = float(selected_size) / float(im_size_min)
+ # Prevent the biggest axis from being more than max_size
+ if np.round(im_scale * im_size_max) > max_size:
+ im_scale = float(max_size) / float(im_size_max)
+ im_scale_x = im_scale
+ im_scale_y = im_scale
+
+ resize_w = np.round(im_scale_x * float(im_shape[1]))
+ resize_h = np.round(im_scale_y * float(im_shape[0]))
+ im_info = [resize_h, resize_w, im_scale]
+ else:
+ im_scale_x = float(selected_size) / float(im_shape[1])
+ im_scale_y = float(selected_size) / float(im_shape[0])
+
+ resize_w = selected_size
+ resize_h = selected_size
+
+ im = Image.fromarray(im)
+    im = im.resize((int(resize_w), int(resize_h)), Image.BILINEAR)
+ im = np.array(im)
+ return im, im_scale_x
+
+
+def _rotation(image,
+ gt_boxes,
+ gt_label,
+ prob,
+ fixed_angle=-1,
+ r_range=(360, 0),
+ gt_margin=1.4):
+ rotate_range = r_range[0]
+ shift = r_range[1]
+ angle = np.array([np.max([0, fixed_angle])])
+ if np.random.rand() <= prob:
+ angle = np.array(
+ np.random.rand(1) * rotate_range - shift, dtype=np.int16)
+ '''
+ rotate image
+ '''
+ image = np.array(image)
+ (h, w) = image.shape[:2]
+ scale = 1.0
+ # set the rotation center
+ center = (w / 2, h / 2)
+ # anti-clockwise angle in the function
+ M = cv2.getRotationMatrix2D(center, angle, scale)
+ image = cv2.warpAffine(image, M, (w, h))
+ # back to PIL image
+ im_width, im_height = w, h
+ '''
+ rotate boxes
+ '''
+ need_gt_boxes = gt_boxes.copy()
+ origin_gt_boxes = need_gt_boxes
+ rotated_gt_boxes = np.empty((len(need_gt_boxes), 5), dtype=np.float32)
+ # anti-clockwise to clockwise arc
+ cos_cita = np.cos(np.pi / 180 * angle)
+ sin_cita = np.sin(np.pi / 180 * angle)
+ # clockwise matrix
+ rotation_matrix = np.array([[cos_cita, sin_cita], [-sin_cita, cos_cita]])
+ pts_ctr = origin_gt_boxes[:, 0:2]
+ pts_ctr = pts_ctr - np.tile((im_width / 2, im_height / 2),
+ (gt_boxes.shape[0], 1))
+ pts_ctr = np.array(np.dot(pts_ctr, rotation_matrix), dtype=np.int16)
+ pts_ctr = np.squeeze(
+ pts_ctr, axis=-1) + np.tile((im_width / 2, im_height / 2),
+ (gt_boxes.shape[0], 1))
+ origin_gt_boxes[:, 0:2] = pts_ctr
+ len_of_gt = len(origin_gt_boxes)
+ # rectificate the angle in the range of [-45, 45]
+ for idx in range(len_of_gt):
+ ori_angle = origin_gt_boxes[idx, 4]
+ height = origin_gt_boxes[idx, 3]
+ width = origin_gt_boxes[idx, 2]
+ # step 1: normalize gt (-45,135)
+ if width < height:
+ ori_angle += 90
+ width, height = height, width
+ # step 2: rotate (-45,495)
+ rotated_angle = ori_angle + angle
+ # step 3: normalize rotated_angle (-45,135)
+ while rotated_angle > 135:
+ rotated_angle = rotated_angle - 180
+ rotated_gt_boxes[idx, 0] = origin_gt_boxes[idx, 0]
+ rotated_gt_boxes[idx, 1] = origin_gt_boxes[idx, 1]
+ rotated_gt_boxes[idx, 3] = height * gt_margin
+ rotated_gt_boxes[idx, 2] = width * gt_margin
+ rotated_gt_boxes[idx, 4] = rotated_angle
+ x_inbound = np.logical_and(rotated_gt_boxes[:, 0] >= 0,
+ rotated_gt_boxes[:, 0] < im_width)
+ y_inbound = np.logical_and(rotated_gt_boxes[:, 1] >= 0,
+ rotated_gt_boxes[:, 1] < im_height)
+ inbound = np.logical_and(x_inbound, y_inbound)
+ need_gt_boxes = rotated_gt_boxes[inbound]
+ need_gt_label = gt_label.copy()
+ need_gt_label = need_gt_label[inbound]
+ return image, need_gt_boxes, need_gt_label
+
+
+def prep_im_for_blob(im, pixel_means, target_size, max_size):
+ """Prepare an image for use as a network input blob. Specially:
+ - Subtract per-channel pixel mean
+ - Convert to float32
+ - Rescale to each of the specified target size (capped at max_size)
+ Returns a list of transformed images, one for each target size. Also returns
+ the scale factors that were used to compute each returned image.
+ """
+ im = im.astype(np.float32, copy=False)
+ im -= pixel_means
+
+ im_shape = im.shape
+ im_size_min = np.min(im_shape[0:2])
+ im_size_max = np.max(im_shape[0:2])
+ im_scale = float(target_size) / float(im_size_min)
+ # Prevent the biggest axis from being more than max_size
+ if np.round(im_scale * im_size_max) > max_size:
+ im_scale = float(max_size) / float(im_size_max)
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale,
+ fy=im_scale,
+ interpolation=cv2.INTER_LINEAR)
+ im_height, im_width, channel = im.shape
+    channel_swap = (2, 0, 1)  # (height, width, channel) -> (channel, height, width)
+ im = im.transpose(channel_swap)
+ return im, im_scale
diff --git a/PaddleCV/rrpn/edict.py b/PaddleCV/rrpn/edict.py
new file mode 100755
index 0000000000000000000000000000000000000000..552ede8e4006b5d4e90dd85d566749fd624c26d1
--- /dev/null
+++ b/PaddleCV/rrpn/edict.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+
+ def __getattr__(self, name):
+ if name in self.__dict__:
+ return self.__dict__[name]
+ elif name in self:
+ return self[name]
+ else:
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ if name in self.__dict__:
+ self.__dict__[name] = value
+ else:
+ self[name] = value
diff --git a/PaddleCV/rrpn/eval.py b/PaddleCV/rrpn/eval.py
new file mode 100755
index 0000000000000000000000000000000000000000..bf7732071967cab8766b9512c6007efb8e23db8a
--- /dev/null
+++ b/PaddleCV/rrpn/eval.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import cv2
+import time
+import numpy as np
+import pickle
+import paddle
+import paddle.fluid as fluid
+import reader
+import models.model_builder as model_builder
+import models.resnet as resnet
+import checkpoint as checkpoint
+from config import cfg
+from utility import print_arguments, parse_args, check_gpu
+from data_utils import DatasetPath
+from eval_helper import *
+import logging
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+
+def eval():
+
+ place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
+ class_nums = cfg.class_num
+ model = model_builder.RRPN(
+ add_conv_body_func=resnet.ResNet(),
+ add_roi_box_head_func=resnet.ResNetC5(),
+ use_pyreader=False,
+ mode='val')
+
+ startup_prog = fluid.Program()
+ infer_prog = fluid.Program()
+ with fluid.program_guard(infer_prog, startup_prog):
+ with fluid.unique_name.guard():
+ model.build_model(image_shape)
+ pred_boxes = model.eval_bbox_out()
+ infer_prog = infer_prog.clone(True)
+ exe.run(startup_prog)
+
+ # yapf: disable
+ def if_exist(var):
+ return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
+ if cfg.pretrained_model:
+ checkpoint.load_params(exe, infer_prog, cfg.pretrained_model)
+ # yapf: enable
+ test_reader = reader.test(1)
+ feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
+
+ fetch_list = [pred_boxes]
+ res_list = []
+ keys = [
+ 'bbox', 'gt_box', 'gt_class', 'is_crowed', 'im_info', 'im_id',
+ 'is_difficult'
+ ]
+ for i, data in enumerate(test_reader()):
+ im_info = [data[0][1]]
+ result = exe.run(infer_prog,
+ fetch_list=[v.name for v in fetch_list],
+ feed=feeder.feed(data),
+ return_numpy=False)
+ pred_boxes_v = result[0]
+ nmsed_out = pred_boxes_v
+ outs = np.array(nmsed_out)
+ res = get_key_dict(outs, data[0], keys)
+ res_list.append(res)
+ if i % 50 == 0:
+ logger.info('test_iter {}'.format(i))
+ icdar_eval(res_list)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ print_arguments(args)
+ check_gpu(args.use_gpu)
+ eval()
diff --git a/PaddleCV/rrpn/eval_helper.py b/PaddleCV/rrpn/eval_helper.py
new file mode 100755
index 0000000000000000000000000000000000000000..c9e66e67cbb740785cc8b1509006a750d7b0158f
--- /dev/null
+++ b/PaddleCV/rrpn/eval_helper.py
@@ -0,0 +1,379 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import math
+import numpy as np
+import paddle.fluid as fluid
+from config import cfg
+import six
+import cv2
+import Polygon as plg
+from PIL import Image
+from PIL import ImageDraw
+from PIL import ImageFont
+import logging
+logger = logging.getLogger(__name__)
+
+
+def get_key_dict(out, data, key):
+ res = {}
+ for i in range(len(key)):
+ if i == 0:
+ res[key[i]] = out
+ else:
+ res[key[i]] = data[i]
+ return res
+
+
+def get_labels_maps():
+ default_labels_maps = {1: 'text'}
+ if cfg.dataset == 'icdar2015':
+ return default_labels_maps
+
+ labels_map = {}
+ with open(os.path.join(cfg.data_dir, 'label_list')) as f:
+ lines = f.readlines()
+ for idx, line in enumerate(lines):
+ labels_map[idx + 1] = line.strip()
+ return labels_map
+
+
+def draw_bounding_box_on_image(image_path,
+                               image_name,
+                               nms_out,
+                               im_scale,
+                               draw_threshold=0.8):
+    image = Image.open(os.path.join(image_path, image_name))
+    draw = ImageDraw.Draw(image)
+    im_width, im_height = image.size
+
+    labels_map = get_labels_maps()
+    for dt in np.array(nms_out):
+        num_id, score = dt.tolist()[:2]
+        # rescale the polygon coordinates back to the original image size
+        x1, y1, x2, y2, x3, y3, x4, y4 = dt[2:] / im_scale
+        if score < draw_threshold:
+            continue
+        draw.line(
+            [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
+            width=2,
+            fill='red')
+        if image.mode == 'RGB':
+            draw.text((x1, y1), labels_map[num_id], (255, 255, 0))
+    print("image with bbox drawn saved as {}".format(image_name))
+    image.save(image_name)
+
+
+def polygon_from_points(points):
+ """
+ Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
+ """
+ res_boxes = np.empty([1, 8], dtype='int32')
+ res_boxes[0, 0] = int(points[0])
+ res_boxes[0, 4] = int(points[1])
+ res_boxes[0, 1] = int(points[2])
+ res_boxes[0, 5] = int(points[3])
+ res_boxes[0, 2] = int(points[4])
+ res_boxes[0, 6] = int(points[5])
+ res_boxes[0, 3] = int(points[6])
+ res_boxes[0, 7] = int(points[7])
+ point_mat = res_boxes[0].reshape([2, 4]).T
+ return plg.Polygon(point_mat)
+
+
+def clip_box(bbox, im_info):
+
+ h = im_info[0]
+ w = im_info[1]
+ res = []
+ for b in bbox:
+ pts = b.reshape(4, 2)
+ pts[np.where(pts < 0)] = 1
+ pts[np.where(pts[:, 0] > w), 0] = w - 1
+ pts[np.where(pts[:, 1] > h), 1] = h - 1
+ pts = pts.reshape(-1)
+ pts /= im_info[2]
+ res.append(pts)
+
+ return np.array(res)
+
+
+def get_union(det, gt):
+ area_det = det.area()
+ area_gt = gt.area()
+ return area_det + area_gt - get_intersection(det, gt)
+
+
+def get_intersection_over_union(det, gt):
+    try:
+        return get_intersection(det, gt) / get_union(det, gt)
+    except ZeroDivisionError:
+        return 0
+
+
+def get_intersection(det, gt):
+ inter = det & gt
+ if len(inter) == 0:
+ return 0
+ return inter.area()
+
+
+def parse_gt(result, im_id):
+ for res in result:
+ if res['im_id'] == im_id:
+ gt_boxes = list(res['gt_box'])
+ gt_class = res['gt_class']
+ is_difficult = res['is_difficult'].reshape(-1)
+ objects = []
+ for i in range(len(gt_boxes)):
+ object_struct = {}
+ object_struct['bbox'] = gt_boxes[i]
+ object_struct['class'] = gt_class[i]
+ if is_difficult[i] == 1:
+ object_struct['difficult'] = 1
+ else:
+ object_struct['difficult'] = 0
+ object_struct['im_id'] = im_id
+ objects.append(object_struct)
+ return objects
+
+
+def calculate_ap(rec, prec):
+ # 11 point metric
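+    # (PASCAL VOC 2007 style: average the maximum precision at the
+    # 11 recall thresholds 0.0, 0.1, ..., 1.0)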
+ ap = 0.
+ for t in np.arange(0., 1.1, 0.1):
+ if np.sum(rec >= t) == 0:
+ p = 0
+ else:
+ p = np.max(prec[rec >= t])
+ ap = ap + p / 11.
+ return ap
+
+
+def icdar_map(result, class_name, ovthresh):
+ im_ids = []
+ for res in result:
+ im_ids.append(res['im_id'])
+ recs = {}
+
+ for i, im_id in enumerate(im_ids):
+ recs[str(im_id)] = parse_gt(result, im_id)
+ class_recs = {}
+ npos = 0
+ for k in im_ids:
+ res = [obj for obj in recs[str(k)] if obj['class'] == class_name]
+ bbox = np.array([x['bbox'] for x in res])
+ difficult = np.array([x['difficult'] for x in res]).astype(np.bool)
+ det = [False] * len(res)
+ npos = npos + sum(~difficult)
+ class_recs[k] = {'bbox': bbox, 'difficult': difficult, 'det': det}
+ image_ids = []
+ confidence = []
+ bbox = []
+ for res in result:
+ im_info = res['im_info']
+ pred_boxes = res['bbox']
+ for box in pred_boxes:
+ if box[0] == class_name:
+ image_ids.append(res['im_id'])
+ confidence.append(box[1])
+ clipd_box = clip_box(box[2:].reshape(-1, 8), im_info)
+ bbox.append(clipd_box[0])
+ confidence = np.array(confidence)
+ sorted_ind = np.argsort(-confidence)
+ sorted_scores = np.sort(-confidence)
+ bbox = np.array(bbox)
+ bbox = bbox[sorted_ind, :]
+ image_ids = [image_ids[x] for x in sorted_ind]
+ nd = len(image_ids)
+ tp = np.zeros(nd)
+ fp = np.zeros(nd)
+ for d in range(nd):
+ res = class_recs[image_ids[d]]
+ bb = bbox[d, :].astype(float)
+ ovmax = -np.inf
+ gt_bbox = res['bbox'].astype(float)
+ if gt_bbox.size > 0:
+ # compute overlaps
+ gt_bbox_xmin = np.min(gt_bbox[:, 0::2], axis=1)
+ gt_bbox_ymin = np.min(gt_bbox[:, 1::2], axis=1)
+ gt_bbox_xmax = np.max(gt_bbox[:, 0::2], axis=1)
+ gt_bbox_ymax = np.max(gt_bbox[:, 1::2], axis=1)
+ bb_xmin = np.min(bb[0::2])
+ bb_ymin = np.min(bb[1::2])
+ bb_xmax = np.max(bb[0::2])
+ bb_ymax = np.max(bb[1::2])
+
+ ixmin = np.maximum(gt_bbox_xmin, bb_xmin)
+ iymin = np.maximum(gt_bbox_ymin, bb_ymin)
+ ixmax = np.minimum(gt_bbox_xmax, bb_xmax)
+ iymax = np.minimum(gt_bbox_ymax, bb_ymax)
+ iw = np.maximum(ixmax - ixmin + 1., 0.)
+ ih = np.maximum(iymax - iymin + 1., 0.)
+ inters = iw * ih
+
+ # union
+ uni = ((bb_xmax - bb_xmin + 1.) * (bb_ymax - bb_ymin + 1.) +
+ (gt_bbox_xmax - gt_bbox_xmin + 1.) *
+ (gt_bbox_ymax - gt_bbox_ymin + 1.) - inters)
+
+ overlaps = inters / uni
+ gt_bbox_keep_mask = overlaps > 0
+ gt_bbox_keep = gt_bbox[gt_bbox_keep_mask, :]
+ gt_bbox_keep_index = np.where(overlaps > 0)[0]
+
+ def calcoverlaps(gt_bbox_keep, bb):
+ overlaps = []
+ for index, _ in enumerate(gt_bbox_keep):
+ p_g = polygon_from_points(gt_bbox_keep[index])
+ p_d = polygon_from_points(bb)
+ overlap = get_intersection_over_union(p_d, p_g)
+ overlaps.append(overlap)
+ return overlaps
+
+ if len(gt_bbox_keep) > 0:
+ overlaps = calcoverlaps(gt_bbox_keep, bb)
+
+ ovmax = np.max(overlaps)
+ jmax = np.argmax(overlaps)
+ jmax = gt_bbox_keep_index[jmax]
+
+ if ovmax > ovthresh:
+ if not res['difficult'][jmax]:
+ if not res['det'][jmax]:
+ tp[d] = 1.
+ res['det'][jmax] = 1
+ else:
+ fp[d] = 1.
+ else:
+ fp[d] = 1.
+ # compute precision recall
+ fp = np.cumsum(fp)
+ tp = np.cumsum(tp)
+
+ rec = tp / float(npos)
+ prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+ ap = calculate_ap(rec, prec)
+ return rec, prec, ap
+
+
+def icdar_map_eval(result, num_class):
+ map = 0
+ for i in range(num_class - 1):
+ rec, prec, ap = icdar_map(result, i + 1, ovthresh=0.5)
+ map = map + ap
+ map = map / (num_class - 1)
+ logger.info('mAP {}'.format(map))
+
+
+def icdar_box_eval(result, thresh):
+
+ matched_sum = 0
+ num_global_care_gt = 0
+ num_global_care_det = 0
+ for res in result:
+ im_info = res['im_info']
+        h = im_info[0]
+        w = im_info[1]
+ gt_boxes = res['gt_box']
+ pred_boxes = res['bbox']
+ pred_boxes = pred_boxes[np.where(pred_boxes[:, 1] > thresh)]
+ pred_boxes = pred_boxes[:, 2:]
+ pred_boxes = clip_box(pred_boxes, im_info)
+
+ is_difficult = res['is_difficult']
+ det_matched = 0
+
+ iou_mat = np.empty([1, 1])
+
+ gt_pols = []
+ det_pols = []
+
+ gt_pol_points = []
+ det_pol_points = []
+
+ gt_dont_care_pols_num = []
+ det_dont_care_pols_num = []
+
+ det_matched_nums = []
+
+ points_list = list(gt_boxes)
+
+    dont_care = is_difficult.reshape(-1)
+ for i, points in enumerate(points_list):
+ gt_pol = polygon_from_points(list(points))
+ gt_pols.append(gt_pol)
+ gt_pol_points.append(list(points))
+        if dont_care[i] == 1:
+ gt_dont_care_pols_num.append(len(gt_pols) - 1)
+ for i, points in enumerate(pred_boxes):
+ points = list(points.reshape(8).astype(np.int32))
+ det_pol = polygon_from_points(points)
+ det_pols.append(det_pol)
+ det_pol_points.append(points)
+ if len(gt_dont_care_pols_num) > 0:
+ for dont_care_pol in gt_dont_care_pols_num:
+ dont_care_pol = gt_pols[dont_care_pol]
+ intersected_area = get_intersection(dont_care_pol, det_pol)
+ pd_dimensions = det_pol.area()
+ precision = 0 if pd_dimensions == 0 else intersected_area / pd_dimensions
+ if (precision > 0.5):
+ det_dont_care_pols_num.append(len(det_pols) - 1)
+ break
+ if len(gt_pols) > 0 and len(det_pols) > 0:
+        # Calculate IoU and precision matrices
+ output_shape = [len(gt_pols), len(det_pols)]
+ iou_mat = np.empty(output_shape)
+ gt_rect_mat = np.zeros(len(gt_pols), np.int8)
+ det_rect_mat = np.zeros(len(det_pols), np.int8)
+ for gt_num in range(len(gt_pols)):
+ for det_num in range(len(det_pols)):
+                    p_g = gt_pols[gt_num]
+                    p_d = det_pols[det_num]
+                    iou_mat[gt_num, det_num] = get_intersection_over_union(p_d,
+                                                                           p_g)
+
+ for gt_num in range(len(gt_pols)):
+ for det_num in range(len(det_pols)):
+ if gt_rect_mat[gt_num] == 0 and det_rect_mat[
+ det_num] == 0 and gt_num not in gt_dont_care_pols_num and det_num not in det_dont_care_pols_num:
+ if iou_mat[gt_num, det_num] > 0.5:
+ gt_rect_mat[gt_num] = 1
+ det_rect_mat[det_num] = 1
+ det_matched += 1
+ det_matched_nums.append(det_num)
+ num_gt_care = (len(gt_pols) - len(gt_dont_care_pols_num))
+ num_det_care = (len(det_pols) - len(det_dont_care_pols_num))
+ matched_sum += det_matched
+ num_global_care_gt += num_gt_care
+ num_global_care_det += num_det_care
+ method_recall = 0 if num_global_care_gt == 0 else float(
+ matched_sum) / num_global_care_gt
+ method_precision = 0 if num_global_care_det == 0 else float(
+ matched_sum) / num_global_care_det
+ method_hmean = 0 if method_recall + method_precision == 0 else 2 * method_recall * method_precision / (
+ method_recall + method_precision)
+ logger.info('Recall {}'.format(method_recall))
+ logger.info('Precision {}'.format(method_precision))
+ logger.info('F1 {}'.format(method_hmean))
+
+
+def icdar_eval(result):
+ if cfg.dataset == 'icdar2015':
+ icdar_box_eval(result, 0.8)
+ else:
+ icdar_map_eval(result, cfg.class_num)
diff --git a/PaddleCV/rrpn/image/img_119.jpg b/PaddleCV/rrpn/image/img_119.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..01feb6de6bb67ecc39db6dfc041c01306e46a50a
Binary files /dev/null and b/PaddleCV/rrpn/image/img_119.jpg differ
diff --git a/PaddleCV/rrpn/image/img_120.jpg b/PaddleCV/rrpn/image/img_120.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8a318613b7275415599da6b79e33314b174e2172
Binary files /dev/null and b/PaddleCV/rrpn/image/img_120.jpg differ
diff --git a/PaddleCV/rrpn/infer.py b/PaddleCV/rrpn/infer.py
new file mode 100755
index 0000000000000000000000000000000000000000..3af9d21c2e2da456a5f719225c327633d97f6eb1
--- /dev/null
+++ b/PaddleCV/rrpn/infer.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import time
+import numpy as np
+import pickle
+import paddle
+import paddle.fluid as fluid
+import reader
+import models.model_builder as model_builder
+import models.resnet as resnet
+import checkpoint as checkpoint
+from config import cfg
+from data_utils import DatasetPath
+from eval_helper import *
+from utility import print_arguments, parse_args, check_gpu
+
+
+def infer():
+ place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
+ class_nums = cfg.class_num
+ model = model_builder.RRPN(
+ add_conv_body_func=resnet.ResNet(),
+ add_roi_box_head_func=resnet.ResNetC5(),
+ use_pyreader=False,
+ mode='infer')
+ startup_prog = fluid.Program()
+ infer_prog = fluid.Program()
+ with fluid.program_guard(infer_prog, startup_prog):
+ with fluid.unique_name.guard():
+ model.build_model(image_shape)
+ pred_boxes = model.eval_bbox_out()
+ infer_prog = infer_prog.clone(True)
+ exe.run(startup_prog)
+
+ # yapf: disable
+ def if_exist(var):
+ return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
+ if cfg.pretrained_model:
+ checkpoint.load_params(exe, infer_prog, cfg.pretrained_model)
+ # yapf: enable
+ infer_reader = reader.infer(cfg.image_path)
+ feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
+
+ fetch_list = [pred_boxes]
+ imgs = os.listdir(cfg.image_path)
+ imgs.sort()
+
+ for i, data in enumerate(infer_reader()):
+ result = exe.run(infer_prog,
+ fetch_list=[v.name for v in fetch_list],
+ feed=feeder.feed(data),
+ return_numpy=False)
+ nmsed_out = result[0]
+ im_info = data[0][1]
+ im_scale = im_info[2]
+ outs = np.array(nmsed_out)
+ draw_bounding_box_on_image(cfg.image_path, imgs[i], outs, im_scale,
+ cfg.draw_threshold)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ print_arguments(args)
+ check_gpu(args.use_gpu)
+ infer()
diff --git a/PaddleCV/rrpn/models/__init__.py b/PaddleCV/rrpn/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PaddleCV/rrpn/models/ext_op/rrpn_lib.py b/PaddleCV/rrpn/models/ext_op/rrpn_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..04c11486a5a4487a1ee8891c4141f29bd3aafdee
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/rrpn_lib.py
@@ -0,0 +1,549 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle.fluid as fluid
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import Variable
+fluid.load_op_library('models/ext_op/src/rrpn_lib.so')
+
+
+def rrpn_target_assign(bbox_pred,
+ cls_logits,
+ anchor_box,
+ gt_boxes,
+ im_info,
+ rpn_batch_size_per_im=256,
+ rpn_straddle_thresh=0.0,
+ rpn_fg_fraction=0.5,
+ rpn_positive_overlap=0.7,
+ rpn_negative_overlap=0.3,
+ use_random=True):
+ """
+    **Target Assign Layer for rotated region proposal network (RRPN).**
+    Given the Intersection-over-Union (IoU) overlap between anchors and
+    ground-truth boxes, this layer assigns classification and regression
+    targets to each anchor; these targets are used to train the RPN. The
+    classification target is a binary class label (object or not object).
+    Following the RRPN paper, two kinds of anchors are labeled positive:
+    (i) the anchor/anchors with the highest IoU overlap with a ground-truth
+    box, and (ii) any anchor whose IoU overlap with some ground-truth box is
+    higher than rpn_positive_overlap (0.7). Note that a single ground-truth
+    box may assign positive labels to multiple anchors. An anchor is labeled
+    negative when its IoU ratio is lower than rpn_negative_overlap (0.3) for
+    all ground-truth boxes. Anchors that are neither positive nor negative
+    do not contribute to the training objective. The regression targets are
+    the encoded ground-truth boxes associated with the positive anchors.
+ Args:
+ bbox_pred(Variable): A 3-D Tensor with shape [N, M, 5] represents the
+ predicted locations of M bounding bboxes. N is the batch size,
+ and each bounding box has five coordinate values and the layout
+ is [x, y, w, h, angle]. The data type can be float32 or float64.
+        cls_logits(Variable): A 3-D Tensor with shape [N, M, 1] represents the
+            predicted confidence predictions. N is the batch size, 1 is the
+            sigmoid score of being foreground or background, M is the number
+            of bounding boxes. The data type can be float32 or float64.
+        anchor_box(Variable): A 2-D Tensor with shape [M, 5] holds M boxes,
+            each box is represented as [x, y, w, h, angle]. [x, y] is the
+            center coordinate of the anchor box, [w, h] is the width and
+            height of the anchor box, and angle is the rotation angle of
+            the box. The data type can be float32 or float64.
+ gt_boxes (Variable): The ground-truth bounding boxes (bboxes) are a 2D
+ LoDTensor with shape [Ng, 5], Ng is the total number of ground-truth
+ bboxes of mini-batch input. The data type can be float32 or float64.
+        im_info (Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch
+            size; each row stores the image height, width and scale.
+ rpn_batch_size_per_im(int): Total number of RPN examples per image.
+ The data type must be int32.
+ rpn_straddle_thresh(float): Remove RPN anchors that go outside the image
+ by straddle_thresh pixels. The data type must be float32.
+ rpn_fg_fraction(float): Target fraction of RoI minibatch that is labeled
+ foreground (i.e. class > 0), 0-th class is background. The data type must be float32.
+ rpn_positive_overlap(float): Minimum overlap required between an anchor
+ and ground-truth box for the (anchor, gt box) pair to be a positive
+ example. The data type must be float32.
+ rpn_negative_overlap(float): Maximum overlap allowed between an anchor
+ and ground-truth box for the (anchor, gt box) pair to be a negative
+ examples. The data type must be float32.
+ use_random(bool): Whether to sample randomly when sampling.
+ Returns:
+ tuple:
+ A tuple(predicted_scores, predicted_location, target_label,
+ target_bbox) is returned. The predicted_scores
+ and predicted_location is the predicted result of the RPN.
+ The target_label and target_bbox is the ground truth,
+ respectively. The predicted_location is a 2D Tensor with shape
+ [F, 5], and the shape of target_bbox is same as the shape of
+ the predicted_location, F is the number of the foreground
+            anchors. The predicted_scores is a 2D Tensor with shape
+            [F + B, 1], and the shape of target_label is same as the shape
+            of the predicted_scores, B is the number of the background
+            anchors; F and B depend on the input of this operator.
+            Bbox_inside_weight represents whether the predicted loc is fake_fg
+            or not and the shape is [F, 5].
+ Examples:
+ .. code-block:: python
+ import paddle.fluid as fluid
+ bbox_pred = fluid.data(name='bbox_pred', shape=[None, 5], dtype='float32')
+ cls_logits = fluid.data(name='cls_logits', shape=[None, 1], dtype='float32')
+ anchor_box = fluid.data(name='anchor_box', shape=[None, 5], dtype='float32')
+ gt_boxes = fluid.data(name='gt_boxes', shape=[None, 5], dtype='float32')
+        im_info = fluid.data(name='im_info', shape=[None, 3], dtype='float32')
+ loc, score, loc_target, score_target = rrpn_target_assign(
+ bbox_pred, cls_logits, anchor_box, gt_boxes, im_info)
+ """
+
+ helper = LayerHelper('rrpn_target_assign', **locals())
+ # Assign target label to anchors
+ loc_index = helper.create_variable_for_type_inference(dtype='int32')
+ score_index = helper.create_variable_for_type_inference(dtype='int32')
+ target_label = helper.create_variable_for_type_inference(dtype='int32')
+ target_bbox = helper.create_variable_for_type_inference(
+ dtype=anchor_box.dtype)
+ helper.append_op(
+ type="rrpn_target_assign",
+ inputs={'Anchor': anchor_box,
+ 'GtBoxes': gt_boxes,
+ 'ImInfo': im_info},
+ outputs={
+ 'LocationIndex': loc_index,
+ 'ScoreIndex': score_index,
+ 'TargetLabel': target_label,
+ 'TargetBBox': target_bbox
+ },
+ attrs={
+ 'rpn_batch_size_per_im': rpn_batch_size_per_im,
+ 'rpn_straddle_thresh': rpn_straddle_thresh,
+ 'rpn_positive_overlap': rpn_positive_overlap,
+ 'rpn_negative_overlap': rpn_negative_overlap,
+ 'rpn_fg_fraction': rpn_fg_fraction,
+ 'use_random': use_random
+ })
+
+ loc_index.stop_gradient = True
+ score_index.stop_gradient = True
+ target_label.stop_gradient = True
+ target_bbox.stop_gradient = True
+
+ cls_logits = fluid.layers.reshape(x=cls_logits, shape=(-1, 1))
+ bbox_pred = fluid.layers.reshape(x=bbox_pred, shape=(-1, 5))
+ predicted_cls_logits = fluid.layers.gather(cls_logits, score_index)
+ predicted_bbox_pred = fluid.layers.gather(bbox_pred, loc_index)
+
+ return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
+
+
+def rotated_anchor_generator(input,
+ anchor_sizes=None,
+ aspect_ratios=None,
+ angles=None,
+ variance=[1.0, 1.0, 1.0, 1.0, 1.0],
+ stride=None,
+ offset=0.5,
+ name=None):
+ """
+ **Rotated Anchor generator operator**
+ Generate anchors for RRPN algorithm.
+    Each position of the input produces N anchors, where
+    N = size(anchor_sizes) * size(aspect_ratios) * size(angles).
+    Anchors are generated by looping over aspect_ratios first,
+    then anchor_sizes.
+ Args:
+ input(Variable): 4-D Tensor with shape [N,C,H,W]. The input feature map.
+ anchor_sizes(float32|list|tuple): The anchor sizes of generated
+ anchors, given in absolute pixels e.g. [64., 128., 256., 512.].
+ For instance, the anchor size of 64 means the area of this anchor
+ equals to 64**2. None by default.
+ aspect_ratios(float32|list|tuple): The height / width ratios
+ of generated anchors, e.g. [0.5, 1.0, 2.0]. None by default.
+        angles(list|tuple): Rotated angles of prior boxes. The data type is float32.
+ variance(list|tuple): The variances to be used in box
+ regression deltas. The data type is float32, [1.0, 1.0, 1.0, 1.0, 1.0] by
+ default.
+ stride(list|tuple): The anchors stride across width and height.
+ The data type is float32. e.g. [16.0, 16.0]. None by default.
+ offset(float32): Prior boxes center offset. 0.5 by default.
+ name(str): Name of this layer. None by default.
+ Returns:
+ Anchors(Variable): The output anchors with a layout of [H, W, num_anchors, 5].
+ H is the height of input, W is the width of input,
+ num_anchors is the box count of each position. Each anchor is
+ in (x, y, w, h, angle) format.
+        Variances(Variable): The expanded variances of anchors with a layout of
+            [H, W, num_priors, 5]. H is the height of input,
+            W is the width of input, num_anchors is the box count
+            of each position. Each variance is in (x, y, w, h, angle) format.
+ Examples:
+ .. code-block:: python
+ import paddle.fluid as fluid
+ conv1 = fluid.data(name='conv1', shape=[None, 48, 16, 16], dtype='float32')
+ anchor, var = rotated_anchor_generator(
+ input=conv1,
+ anchor_sizes=[128, 256, 512],
+ aspect_ratios=[0.2, 0.5, 1.0],
+ variance=[1.0, 1.0, 1.0, 1.0, 1.0],
+ stride=[16.0, 16.0],
+ offset=0.5)
+ """
+ helper = LayerHelper("rotated_anchor_generator", **locals())
+ dtype = helper.input_dtype()
+
+ def _is_list_or_tuple_(data):
+ return (isinstance(data, list) or isinstance(data, tuple))
+
+ if not _is_list_or_tuple_(anchor_sizes):
+ anchor_sizes = [anchor_sizes]
+ if not _is_list_or_tuple_(aspect_ratios):
+ aspect_ratios = [aspect_ratios]
+ if not _is_list_or_tuple_(angles):
+ angles = [angles]
+ if not (_is_list_or_tuple_(stride) and len(stride) == 2):
+        raise ValueError('stride should be a list or tuple '
+                         'with length 2, (stride_width, stride_height).')
+
+ anchor_sizes = list(map(float, anchor_sizes))
+ aspect_ratios = list(map(float, aspect_ratios))
+ angles = list(map(float, angles))
+ stride = list(map(float, stride))
+
+ attrs = {
+ 'anchor_sizes': anchor_sizes,
+ 'aspect_ratios': aspect_ratios,
+ 'angles': angles,
+ 'variances': variance,
+ 'stride': stride,
+ 'offset': offset
+ }
+
+ anchor = helper.create_variable_for_type_inference(dtype)
+ var = helper.create_variable_for_type_inference(dtype)
+ helper.append_op(
+ type="rotated_anchor_generator",
+ inputs={"Input": input},
+ outputs={"Anchors": anchor,
+ "Variances": var},
+ attrs=attrs, )
+ anchor.stop_gradient = True
+ var.stop_gradient = True
+ return anchor, var
+
+
+def rrpn_box_coder(prior_box, prior_box_var, target_box, name=None):
+ """
+ Args:
+ prior_box(Variable): Box list prior_box is a 2-D Tensor with shape
+ [M, 5] holds M boxes and data type is float32 or float64. Each box
+ is represented as [x, y, w, h, angle], [x, y] is the
+ center coordinate of the anchor box, [w, h] is the width and height
+ of the anchor box, angle is rotated angle of prior_box.
+        prior_box_var(List|Variable|None): prior_box_var is a 2-D Tensor with
+            shape [M, 5] that holds M groups of variances.
+ target_box(Variable): This input can be a 2-D LoDTensor with shape
+ [M, 5]. Each box is represented as [x, y, w, h, angle]. The data
+ type is float32 or float64.
+ name(str): Name of this layer. None by default.
+ Returns:
+ Variable:
+ output_box(Variable): The output tensor of rrpn_box_coder_op with shape [N, 5] representing the
+ result of N target boxes encoded with N Prior boxes and variances.
+ N represents the number of box and 5 represents [x, y, w, h ,angle].
+ Examples:
+
+ .. code-block:: python
+
+ import paddle.fluid as fluid
+ prior_box_decode = fluid.data(name='prior_box_decode',
+ shape=[512, 5],
+ dtype='float32')
+ target_box_decode = fluid.data(name='target_box_decode',
+ shape=[512, 5],
+ dtype='float32')
+ output_decode = rrpn_box_coder(prior_box=prior_box_decode,
+ prior_box_var=[10, 10, 5, 5, 1],
+ target_box=target_box_decode)
+ """
+
+ helper = LayerHelper("rrpn_box_coder", **locals())
+
+ if name is None:
+ output_box = helper.create_variable_for_type_inference(
+ dtype=prior_box.dtype)
+ else:
+ output_box = helper.create_variable(
+ name=name, dtype=prior_box.dtype, persistable=False)
+
+ inputs = {"PriorBox": prior_box, "TargetBox": target_box}
+ attrs = {}
+ if isinstance(prior_box_var, Variable):
+ inputs['PriorBoxVar'] = prior_box_var
+ elif isinstance(prior_box_var, list):
+ attrs['variance'] = prior_box_var
+ else:
+ raise TypeError(
+ "Input variance of rrpn_box_coder must be Variable or list")
+ helper.append_op(
+ type="rrpn_box_coder",
+ inputs=inputs,
+ attrs=attrs,
+ outputs={"OutputBox": output_box})
+ return output_box
+
+
+def rotated_roi_align(input,
+ rois,
+ pooled_height=1,
+ pooled_width=1,
+ spatial_scale=1.0,
+ name=None):
+ """
+    **RotatedRoIAlign Operator**
+
+    Rotated Region of Interest align (also known as Rotated RoI Align) performs
+    bilinear interpolation on inputs of nonuniform sizes to obtain
+    fixed-size feature maps (e.g. 7*7).
+
+    Each region proposal is divided into equal-sized sections according to
+    pooled_width and pooled_height, keeping the location of the original
+    proposal.
+
+    Each RoI bin is transformed to become horizontal by a perspective
+    transformation, and the values in each RoI bin are computed directly
+    through bilinear interpolation. The output is the mean of all values,
+    which avoids the misalignment problem.
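+    Examples:
+        .. code-block:: python
+
+            # a sketch following the other examples in this file;
+            # names and shapes are illustrative.
+            import paddle.fluid as fluid
+            x = fluid.data(name='x', shape=[None, 1024, 84, 84], dtype='float32')
+            rois = fluid.data(name='rois', shape=[None, 5], dtype='float32', lod_level=1)
+            align_out = rotated_roi_align(input=x,
+                                          rois=rois,
+                                          pooled_height=14,
+                                          pooled_width=14,
+                                          spatial_scale=0.0625)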
+ """
+ helper = LayerHelper('rrpn_rotated_roi_align', **locals())
+ dtype = helper.input_dtype()
+ align_out = helper.create_variable_for_type_inference(dtype)
+ cx = helper.create_variable_for_type_inference('float32')
+ cy = helper.create_variable_for_type_inference('float32')
+ helper.append_op(
+ type="rrpn_rotated_roi_align",
+ inputs={"X": input,
+ "ROIs": rois},
+ outputs={"Out": align_out,
+ "ConIdX": cx,
+ "ConIdY": cy},
+ attrs={
+ "pooled_height": pooled_height,
+ "pooled_width": pooled_width,
+ "spatial_scale": spatial_scale,
+ })
+ return align_out
+
+
+def rotated_generate_proposal_labels(rpn_rois,
+ gt_classes,
+ is_crowd,
+ gt_boxes,
+ im_info,
+ batch_size_per_im=256,
+ fg_fraction=0.25,
+ fg_thresh=0.25,
+ bg_thresh_hi=0.5,
+ bg_thresh_lo=0.0,
+ bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+ class_nums=None,
+ use_random=True,
+ is_cls_agnostic=False):
+ """
+ **Rotated Generate Proposal Labels**
+    Given the bounding boxes produced by RotatedGenerateProposalOp and the
+    groundtruth, this operator samples foreground and background boxes and
+    computes the loss targets.
+    RpnRois are the output boxes of the RPN, processed by rotated_generate_proposal_op.
+    These boxes are combined with the groundtruth boxes and sampled according to
+    batch_size_per_im and fg_fraction.
+    An instance whose overlap with the groundtruth is greater than fg_thresh is
+    considered a foreground sample; an instance whose overlap is greater than
+    bg_thresh_lo and lower than bg_thresh_hi is considered a background sample.
+    After all foreground and background boxes are chosen (the so-called RoIs),
+    random sampling ensures that the number of foreground boxes is no more than
+    batch_size_per_im * fg_fraction.
+    For each box in the RoIs, the classification target (class label) and
+    regression targets (box label) are assigned.
+    Finally, BboxInsideWeights and BboxOutsideWeights specify whether each RoI
+    contributes to the training loss.
+ Args:
+        rpn_rois(Variable): A 2-D LoDTensor with shape [N, 5]. N is the number of boxes output by RotatedGenerateProposalOp, each element is a bounding box in [x, y, w, h, angle] format. The data type can be float32 or float64.
+        gt_classes(Variable): A 2-D LoDTensor with shape [M, 1]. M is the number of groundtruth boxes, each element is a class label of the groundtruth. The data type must be int32.
+        is_crowd(Variable): A 2-D LoDTensor with shape [M, 1]. M is the number of groundtruth boxes, each element is a flag indicating whether a groundtruth is crowd. The data type must be int32.
+        gt_boxes(Variable): A 2-D LoDTensor with shape [M, 5]. M is the number of groundtruth boxes, each element is a bounding box in [x, y, w, h, angle] format.
+        im_info(Variable): A 2-D LoDTensor with shape [B, 3]. B is the number of input images, each element consists of im_height, im_width, im_scale.
+        batch_size_per_im(int): Batch size of RoIs per image. The data type must be int32.
+        fg_fraction(float): Foreground fraction of total batch_size_per_im. The data type must be float32.
+        fg_thresh(float): Overlap threshold used to choose foreground samples. The data type must be float32.
+        bg_thresh_hi(float): Overlap threshold upper bound used to choose background samples. The data type must be float32.
+        bg_thresh_lo(float): Overlap threshold lower bound used to choose background samples. The data type must be float32.
+        bbox_reg_weights(list|tuple): Box regression weights. The data type must be float32.
+        class_nums(int): Number of classes. The data type must be int32.
+        use_random(bool): Use random sampling to choose foreground and background boxes.
+        is_cls_agnostic(bool): If True, box regression is class-agnostic and only distinguishes foreground from background boxes.
+ Returns:
+ tuple:
+        A tuple with format ``(rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights)``.
+ - **rois**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 5]``. The data type is the same as ``rpn_rois``.
+ - **labels_int32**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 1]``. The data type must be int32.
+ - **bbox_targets**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 5 * class_num]``. The regression targets of all RoIs. The data type is the same as ``rpn_rois``.
+ - **bbox_inside_weights**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 5 * class_num]``. The weights of foreground boxes' regression loss. The data type is the same as ``rpn_rois``.
+ - **bbox_outside_weights**: 2-D LoDTensor with shape ``[batch_size_per_im * batch_size, 5 * class_num]``. The weights of regression loss. The data type is the same as ``rpn_rois``.
+ Examples:
+        .. code-block:: python
+
+ import paddle.fluid as fluid
+ rpn_rois = fluid.data(name='rpn_rois', shape=[None, 5], dtype='float32')
+            gt_classes = fluid.data(name='gt_classes', shape=[None, 1], dtype='int32')
+            is_crowd = fluid.data(name='is_crowd', shape=[None, 1], dtype='int32')
+ gt_boxes = fluid.data(name='gt_boxes', shape=[None, 5], dtype='float32')
+ im_info = fluid.data(name='im_info', shape=[None, 3], dtype='float32')
+ rois, labels, bbox, inside_weights, outside_weights = rotated_generate_proposal_labels(
+ rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
+ class_nums=10)
+ """
+ helper = LayerHelper('rrpn_generate_proposal_labels', **locals())
+ rois = helper.create_variable_for_type_inference(dtype=rpn_rois.dtype)
+ labels_int32 = helper.create_variable_for_type_inference(
+ dtype=gt_classes.dtype)
+ bbox_targets = helper.create_variable_for_type_inference(
+ dtype=rpn_rois.dtype)
+ bbox_inside_weights = helper.create_variable_for_type_inference(
+ dtype=rpn_rois.dtype)
+ bbox_outside_weights = helper.create_variable_for_type_inference(
+ dtype=rpn_rois.dtype)
+
+ helper.append_op(
+ type="rrpn_generate_proposal_labels",
+ inputs={
+ 'RpnRois': rpn_rois,
+ 'GtClasses': gt_classes,
+ 'IsCrowd': is_crowd,
+ 'GtBoxes': gt_boxes,
+ 'ImInfo': im_info
+ },
+ outputs={
+ 'Rois': rois,
+ 'LabelsInt32': labels_int32,
+ 'BboxTargets': bbox_targets,
+ 'BboxInsideWeights': bbox_inside_weights,
+ 'BboxOutsideWeights': bbox_outside_weights
+ },
+ attrs={
+ 'batch_size_per_im': batch_size_per_im,
+ 'fg_fraction': fg_fraction,
+ 'fg_thresh': fg_thresh,
+ 'bg_thresh_hi': bg_thresh_hi,
+ 'bg_thresh_lo': bg_thresh_lo,
+ 'bbox_reg_weights': bbox_reg_weights,
+ 'class_nums': class_nums,
+ 'use_random': use_random,
+ 'is_cls_agnostic': is_cls_agnostic
+ })
+
+ rois.stop_gradient = True
+ labels_int32.stop_gradient = True
+ bbox_targets.stop_gradient = True
+ bbox_inside_weights.stop_gradient = True
+ bbox_outside_weights.stop_gradient = True
+
+ return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights
+
+
+def rotated_generate_proposals(scores,
+ bbox_deltas,
+ im_info,
+ anchors,
+ variances,
+ pre_nms_top_n=6000,
+ post_nms_top_n=1000,
+ nms_thresh=0.5,
+ min_size=0.1,
+ name=None):
+ """
+    **Rotated Generate Proposals**
+    This operation proposes rotated RoIs according to each box's probability
+    of being a foreground object, where the boxes are computed from the anchors.
+    bbox_deltas and scores are the outputs of RPN, and the final proposals can
+    be used to train the detection net. To generate proposals, this operation
+    performs the following steps:
+    1. Transpose and reshape scores and bbox_deltas to (H*W*A, 1) and (H*W*A, 5).
+    2. Calculate box locations as proposal candidates.
+    3. Remove predicted boxes with small area.
+    4. Apply NMS to get the final proposals as output.
+ Args:
+ scores(Variable): A 4-D Tensor with shape [N, A, H, W] represents
+ the probability for each box to be an object.
+ N is batch size, A is number of anchors, H and W are height and
+ width of the feature map. The data type must be float32.
+        bbox_deltas(Variable): A 4-D Tensor with shape [N, 5*A, H, W]
+            representing the difference between the predicted box location and
+            the anchor location. The data type must be float32.
+        im_info(Variable): A 2-D Tensor with shape [N, 3] representing the origin
+            image information for the N input images. Each entry contains the
+            height, the width and the scale between the origin image size and
+            the size of the feature map. The data type must be float32.
+        anchors(Variable): A 4-D Tensor representing the anchors with a layout
+            of [H, W, A, 5]. H and W are the height and width of the feature map,
+            A is the box count of each position. Each anchor is
+            in (x, y, w, h, angle) format. The data type must be float32.
+        variances(Variable): A 4-D Tensor. The expanded variances of anchors with a
+            layout of [H, W, A, 5]. Each variance is in
+            (x, y, w, h, angle) format. The data type must be float32.
+        pre_nms_top_n(int): Number of total bboxes to be kept per
+            image before NMS. `6000` by default.
+        post_nms_top_n(int): Number of total bboxes to be kept per
+            image after NMS. `1000` by default.
+        nms_thresh(float): Threshold in NMS. The data type must be float32. `0.5` by default.
+        min_size(float): Remove predicted boxes with either height or
+            width < min_size. The data type must be float32. `0.1` by default.
+ Returns:
+ tuple:
+            A tuple with format ``(rpn_rois, rpn_roi_probs)``.
+            - **rpn_rois**: The generated RoIs. 2-D Tensor with shape ``[N, 5]`` where ``N`` is the number of RoIs. The data type is the same as ``scores``.
+            - **rpn_roi_probs**: The scores of the generated RoIs. 2-D Tensor with shape ``[N, 1]`` where ``N`` is the number of RoIs. The data type is the same as ``scores``.
+ Examples:
+ .. code-block:: python
+
+ import paddle.fluid as fluid
+ scores = fluid.data(name='scores', shape=[None, 4, 5, 5], dtype='float32')
+ bbox_deltas = fluid.data(name='bbox_deltas', shape=[None, 20, 5, 5], dtype='float32')
+ im_info = fluid.data(name='im_info', shape=[None, 3], dtype='float32')
+ anchors = fluid.data(name='anchors', shape=[None, 5, 4, 5], dtype='float32')
+ variances = fluid.data(name='variances', shape=[None, 5, 10, 5], dtype='float32')
+ rrois, rroi_probs = fluid.layers.rotated_generate_proposals(scores, bbox_deltas,
+ im_info, anchors, variances)
+ """
+
+ helper = LayerHelper('rrpn_generate_proposals', **locals())
+
+ rpn_rois = helper.create_variable_for_type_inference(
+ dtype=bbox_deltas.dtype)
+ rpn_roi_probs = helper.create_variable_for_type_inference(
+ dtype=scores.dtype)
+ helper.append_op(
+ type="rrpn_generate_proposals",
+ inputs={
+ 'Scores': scores,
+ 'BboxDeltas': bbox_deltas,
+ 'ImInfo': im_info,
+ 'Anchors': anchors,
+ 'Variances': variances
+ },
+ attrs={
+ 'pre_nms_topN': pre_nms_top_n,
+ 'post_nms_topN': post_nms_top_n,
+ 'nms_thresh': nms_thresh,
+ 'min_size': min_size
+ },
+ outputs={'RpnRois': rpn_rois,
+ 'RpnRoiProbs': rpn_roi_probs})
+ rpn_rois.stop_gradient = True
+ rpn_roi_probs.stop_gradient = True
+
+ return rpn_rois, rpn_roi_probs
diff --git a/PaddleCV/rrpn/models/ext_op/src/README.md b/PaddleCV/rrpn/models/ext_op/src/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cbec185403d934909bb65f51b0c07ebbb99e9c03
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/README.md
@@ -0,0 +1,68 @@
+# Compiling the Custom Ops
+
+## Code structure
+
+  - src: C++/CUDA source of the custom ops
+  - rrpn_lib.py: Python wrappers
+
+## Installing PaddlePaddle
+
+Install PaddlePaddle in one of the following ways:
+
+- Build and install from source on the [Paddle develop branch](https://github.com/PaddlePaddle/Paddle/tree/develop), following the instructions for:
+
+ 1. [Ubuntu](https://www.paddlepaddle.org.cn/install/doc/source/ubuntu)
+ 1. [CentOS](https://www.paddlepaddle.org.cn/install/doc/source/centos)
+ 1. [MacOS](https://www.paddlepaddle.org.cn/install/doc/source/macos)
+ 1. [Windows](https://www.paddlepaddle.org.cn/install/doc/source/windows)
+
+ **Note:** building inside docker is recommended.
+
+- Install the Paddle develop [nightly whl package](https://www.paddlepaddle.org.cn/install/doc/tables#多版本whl包列表-dev-11)
+
+ **Note:** the gcc version used to compile the custom ops must match the gcc version used to build Paddle. The Paddle develop nightly build is currently compiled with **gcc 4.8.2**; if you use the nightly package, compile the custom ops with **gcc 4.8.2** as well, otherwise compatibility problems may occur.
+
+## Compiling the custom ops
+
+The C++/CUDA implementation of the custom ops must be compiled into a shared library. make.sh compiles it with g++/nvcc; you may of course write a Makefile or a CMake build instead.
+
+Compilation must include PaddlePaddle's header files and link against PaddlePaddle's libraries. The header and library paths can be obtained as follows:
+
+```
+# python
+>>> import paddle
+>>> print(paddle.sysconfig.get_include())
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/include
+>>> print(paddle.sysconfig.get_lib())
+/paddle/pyenv/local/lib/python2.7/site-packages/paddle/libs
+```
+
+We provide a build script for the shared library; it takes the CUDA, cuDNN and NCCL install paths as arguments:
+
+```
+cd src
+sh make.sh ${CUDA_PATH} ${CUDNN_PATH} ${NCCL_PATH}
+```
+
+The build produces `rrpn_lib.so`.
+
+**Note:** if PaddlePaddle was built from source without setting `WITH_MKLDNN` in `cmake`, compiling the custom ops will fail with errors such as `mkldnn.h` not found; in that case remove the `-DPADDLE_WITH_MKLDNN` option from the compile commands in `make.sh`.
+
+## Setting environment variables
+
+Paddle's core library directory must be added to `LD_LIBRARY_PATH`. First obtain the path by running:
+
+```
+import paddle
+print(paddle.sysconfig.get_lib())
+```
+
+The library path can then be added as follows:
+
+```
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
+```
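+
+After the library is built and its directory is on `LD_LIBRARY_PATH`, the compiled ops must be registered with the framework before the Python wrappers in `rrpn_lib.py` can call them. A minimal sketch, assuming the script is run from the directory containing `rrpn_lib.so`:
+
+```
+import paddle.fluid as fluid
+
+# Register the compiled custom ops with PaddlePaddle.
+fluid.load_op_library('rrpn_lib.so')
+```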
+
+
+For more information on defining custom C++ operators outside the framework, see the [official documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/index_cn.html)
diff --git a/PaddleCV/rrpn/models/ext_op/src/bbox_util.h b/PaddleCV/rrpn/models/ext_op/src/bbox_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..dc978e2c1780312c22f9367bd6e75d32ce6a4867
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/bbox_util.h
@@ -0,0 +1,360 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Based on
+--------------------------------------------------------
+@misc{ma2019rrpn,
+ author = {Jianqi Ma},
+ title = {{RRPN in pytorch}},
+ year = {2019},
+ howpublished = {\url{https://github.com/mjq11302010044/RRPN_pytorch}},
+}
+@article{Jianqi17RRPN,
+ Author = {Jianqi Ma and Weiyuan Shao and Hao Ye and Li Wang and Hong Wang
+and Yingbin Zheng and Xiangyang Xue},
+ Title = {Arbitrary-Oriented Scene Text Detection via Rotation Proposals},
+ journal = {IEEE Transactions on Multimedia},
+ volume={20},
+ number={11},
+ pages={3111-3122},
+ year={2018}
+}
+--------------------------------------------------------
+*/
+
+#pragma once
+#include <cmath>
+#include <cstring>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+
+#define PI 3.141592654
+
+struct RangeInitFunctor {
+ int start;
+ int delta;
+ int* out;
+ HOSTDEVICE void operator()(size_t i) { out[i] = start + i * delta; }
+};
+
+
+// get the triangle area; used after decomposing the intersection polygon into triangles
+template <typename T>
+inline T trangle_area(T* a, T* b, T* c) {
+ return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) * (b[0] - c[0])) / 2.0;
+}
+
+// get the area of the intersection polygon
+template <typename T>
+inline T get_area(T* int_pts, int num_of_inter) {
+ T area = 0.0;
+ for (int i = 0; i < num_of_inter - 2; i++) {
+ area += fabs(
+ trangle_area(int_pts, int_pts + 2 * i + 2, int_pts + 2 * i + 4));
+ }
+ return area;
+}
+
+// sort points by polar angle around their center so the intersection polygon can be decomposed into triangles
+template <typename T>
+inline void reorder_pts(T* int_pts, int num_of_inter) {
+ if (num_of_inter > 0) {
+ T center[2] = {0.0, 0.0};
+
+ for (int i = 0; i < num_of_inter; i++) {
+ center[0] += int_pts[2 * i];
+ center[1] += int_pts[2 * i + 1];
+ }
+ center[0] /= num_of_inter;
+ center[1] /= num_of_inter;
+
+ T vs[16];
+ T v[2];
+ T d;
+ for (int i = 0; i < num_of_inter; i++) {
+ v[0] = int_pts[2 * i] - center[0];
+ v[1] = int_pts[2 * i + 1] - center[1];
+ d = sqrt(v[0] * v[0] + v[1] * v[1]);
+ v[0] = v[0] / d;
+ v[1] = v[1] / d;
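+      // Fold the unit direction into a scalar sort key that increases with the polar angle.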
+ if (v[1] < 0) {
+ v[0] = -2 - v[0];
+ }
+ vs[i] = v[0];
+ }
+
+    T temp, tx, ty;
+ int j;
+ for (int i = 1; i < num_of_inter; ++i) {
+ if (vs[i - 1] > vs[i]) {
+ temp = vs[i];
+ tx = int_pts[2 * i];
+ ty = int_pts[2 * i + 1];
+ j = i;
+ while (j > 0 && vs[j - 1] > temp) {
+ vs[j] = vs[j - 1];
+ int_pts[j * 2] = int_pts[j * 2 - 2];
+ int_pts[j * 2 + 1] = int_pts[j * 2 - 1];
+ j--;
+ }
+ vs[j] = temp;
+ int_pts[j * 2] = tx;
+ int_pts[j * 2 + 1] = ty;
+ }
+ }
+ }
+}
+
+// compute the intersection point, if any, of edge i of pts1 and edge j of pts2
+template <typename T>
+inline bool inter2line(T* pts1, T* pts2, int i, int j, T* temp_pts) {
+ T a[2] = {pts1[2 * i], pts1[2 * i + 1]};
+ T b[2] = {pts1[2 * ((i + 1) % 4)], pts1[2 * ((i + 1) % 4) + 1]};
+ T c[2] = {pts2[2 * j], pts2[2 * j + 1]};
+ T d[2] = {pts2[2 * ((j + 1) % 4)], pts2[2 * ((j + 1) % 4) + 1]};
+
+ T area_abc, area_abd, area_cda, area_cdb;
+
+ area_abc = trangle_area(a, b, c);
+ area_abd = trangle_area(a, b, d);
+
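+  // A crossing requires c and d to lie on strictly opposite sides of ab (opposite-signed areas).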
+ if (area_abc * area_abd >= -1e-5) {
+ return false;
+ }
+
+ area_cda = trangle_area(c, d, a);
+ area_cdb = area_cda + area_abc - area_abd;
+
+ if (area_cda * area_cdb >= -1e-5) {
+ return false;
+ }
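+  // Parameter of the intersection point along segment ab, from the ratio of signed areas.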
+ T t = area_cda / (area_abd - area_abc);
+
+ T dx = t * (b[0] - a[0]);
+ T dy = t * (b[1] - a[1]);
+ temp_pts[0] = a[0] + dx;
+ temp_pts[1] = a[1] + dy;
+
+ return true;
+}
+
+template <typename T>
+inline bool inrect(T pt_x, T pt_y, T* pts) {
+ T ab[2] = {pts[2] - pts[0], pts[3] - pts[1]};
+ T ad[2] = {pts[6] - pts[0], pts[7] - pts[1]};
+ T ap[2] = {pt_x - pts[0], pt_y - pts[1]};
+
+ T abab = ab[0] * ab[0] + ab[1] * ab[1];
+ T abap = ab[0] * ap[0] + ab[1] * ap[1];
+ T adad = ad[0] * ad[0] + ad[1] * ad[1];
+ T adap = ad[0] * ap[0] + ad[1] * ap[1];
+  bool result = (abab - abap >= -1) && (abap >= -1) && (adad - adap >= -1) &&
+                (adap >= -1);
+ return result;
+}
+
+// collect the intersection points of the two boxes and return their count
+template <typename T>
+inline int inter_pts(T* pts1, T* pts2, T* int_pts) {
+ int num_of_inter = 0;
+
+ for (int i = 0; i < 4; i++) {
+ if (inrect(pts1[2 * i], pts1[2 * i + 1], pts2)) {
+ int_pts[num_of_inter * 2] = pts1[2 * i];
+ int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1];
+ num_of_inter++;
+ }
+ if (inrect(pts2[2 * i], pts2[2 * i + 1], pts1)) {
+ int_pts[num_of_inter * 2] = pts2[2 * i];
+ int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1];
+ num_of_inter++;
+ }
+ }
+
+ T out_pts[2];
+
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ bool has_pts = inter2line(pts1, pts2, i, j, out_pts);
+ if (has_pts) {
+ int_pts[num_of_inter * 2] = out_pts[0];
+ int_pts[num_of_inter * 2 + 1] = out_pts[1];
+ num_of_inter++;
+ }
+ }
+ }
+
+
+ return num_of_inter;
+}
+
+// convert x,y,w,h,angle to x1,y1,x2,y2,x3,y3,x4,y4
+template <typename T>
+inline void convert_region(T* pts,
+ const framework::Tensor& _region,
+ int index) {
+  auto region = framework::EigenTensor<T, 2>::From(_region);
+ T angle = region(index, 4);
+ T a_cos = cos(angle / 180.0 * PI);
+ T a_sin = -sin(angle / 180.0 * PI); // anti clock-wise
+
+ T ctr_x = region(index, 0);
+ T ctr_y = region(index, 1);
+ T h = region(index, 3);
+ T w = region(index, 2);
+
+
+ T pts_x[4] = {-w / 2, -w / 2, w / 2, w / 2};
+ T pts_y[4] = {-h / 2, h / 2, h / 2, -h / 2};
+
+ for (int i = 0; i < 4; i++) {
+ pts[2 * i] = a_cos * pts_x[i] - a_sin * pts_y[i] + ctr_x;
+ pts[2 * i + 1] = a_sin * pts_x[i] + a_cos * pts_y[i] + ctr_y;
+ }
+}
+
+
+// Calculate the area of intersection
+template <typename T>
+inline float inter(const framework::Tensor& _region1,
+ const framework::Tensor& _region2,
+ const int& r,
+ const int& c) {
+ T pts1[8];
+ T pts2[8];
+ T int_pts[16];
+ int num_of_inter;
+
+
+ convert_region(pts1, _region1, r);
+ convert_region(pts2, _region2, c);
+
+ num_of_inter = inter_pts(pts1, pts2, int_pts);
+
+ reorder_pts(int_pts, num_of_inter);
+
+ return get_area(int_pts, num_of_inter);
+}
+
+template <typename T>
+inline float devRotateIoU(const framework::Tensor& _region1,
+ const framework::Tensor& _region2,
+ const int r,
+ const int c) {
+  auto __region1 = framework::EigenTensor<T, 2>::From(_region1);
+  auto __region2 = framework::EigenTensor<T, 2>::From(_region2);
+
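+  // Nearly identical boxes short-circuit to IoU = 1 to avoid degenerate polygon clipping.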
+ if ((fabs(__region1(r, 0) - __region2(c, 0)) < 1e-5) &&
+ (fabs(__region1(r, 1) - __region2(c, 1)) < 1e-5) &&
+ (fabs(__region1(r, 2) - __region2(c, 2)) < 1e-5) &&
+ (fabs(__region1(r, 3) - __region2(c, 3)) < 1e-5) &&
+ (fabs(__region1(r, 4) - __region2(c, 4)) < 1e-5)) {
+ return 1.0;
+ }
+ T area1, area2, area_inter;
+ area1 = __region1(r, 2) * __region1(r, 3);
+ area2 = __region2(c, 2) * __region2(c, 3);
+  area_inter = inter<T>(_region1, _region2, r, c);
+ auto result = area_inter / (area1 + area2 - area_inter);
+
+ if (result < 0) {
+ result = 0.0;
+ }
+ // may have bugs which cause overlap > 1
+ if (result > 1.00000001) {
+ result = 0.0;
+ }
+ return result;
+}
+
+
+template <typename T>
+inline void BoxToDelta2(const int box_num,
+ const framework::Tensor& ex_boxes,
+ const framework::Tensor& gt_boxes,
+ const float* weights,
+ framework::Tensor* box_delta) {
+  auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
+  auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
+  auto trg = framework::EigenTensor<T, 2>::From(*box_delta);
+ T ex_w, ex_h, ex_ctr_x, ex_ctr_y, ex_angle, gt_w, gt_h, gt_ctr_x, gt_ctr_y,
+ gt_angle;
+ for (int64_t i = 0; i < box_num; ++i) {
+ ex_w = ex_boxes_et(i, 2);
+ ex_h = ex_boxes_et(i, 3);
+ ex_ctr_x = ex_boxes_et(i, 0);
+ ex_ctr_y = ex_boxes_et(i, 1);
+ ex_angle = ex_boxes_et(i, 4);
+
+ gt_w = gt_boxes_et(i, 2);
+ gt_h = gt_boxes_et(i, 3);
+ gt_ctr_x = gt_boxes_et(i, 0);
+ gt_ctr_y = gt_boxes_et(i, 1);
+ gt_angle = gt_boxes_et(i, 4);
+
+ trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w;
+ trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h;
+ trg(i, 2) = std::log(gt_w / ex_w);
+ trg(i, 3) = std::log(gt_h / ex_h);
+ trg(i, 4) = gt_angle - ex_angle;
+
+ if (weights) {
+ trg(i, 0) = trg(i, 0) * weights[0];
+ trg(i, 1) = trg(i, 1) * weights[1];
+ trg(i, 2) = trg(i, 2) * weights[2];
+ trg(i, 3) = trg(i, 3) * weights[3];
+ trg(i, 4) = trg(i, 4) * weights[4];
+ }
+
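+    // When gt and ex angles straddle the angle-range boundary, wrap the delta by 180 degrees so the regression target stays small.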
+ if (gt_angle <= -30 && ex_angle >= 120) {
+ trg(i, 4) = trg(i, 4) + 180.0;
+ }
+ if (gt_angle >= 120 && ex_angle <= -30) {
+ trg(i, 4) = trg(i, 4) - 180.0;
+ }
+ trg(i, 4) = (PI / 180) * trg(i, 4);
+ }
+}
+
+
+template <typename T>
+void Gather(
+ const T* in, const int in_stride, const int* index, const int num, T* out) {
+ const int stride_bytes = in_stride * sizeof(T);
+ for (int i = 0; i < num; ++i) {
+ int id = index[i];
+ memcpy(out + i * in_stride, in + id * in_stride, stride_bytes);
+ }
+}
+
+template <typename T>
+void BboxOverlaps2(const framework::Tensor& r_boxes,
+ const framework::Tensor& c_boxes,
+ framework::Tensor* overlaps) {
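+  // Pairwise rotated IoU: overlaps_et(i, j) = IoU(r_boxes[i], c_boxes[j]).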
+  auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
+ int r_num = r_boxes.dims()[0];
+ int c_num = c_boxes.dims()[0];
+ for (int i = 0; i < r_num; ++i) {
+ for (int j = 0; j < c_num; ++j) {
+      overlaps_et(i, j) = devRotateIoU<T>(r_boxes, c_boxes, i, j);
+ }
+ }
+}
+
+
+} // namespace operators
+} // namespace paddle
diff --git a/PaddleCV/rrpn/models/ext_op/src/blas.h b/PaddleCV/rrpn/models/ext_op/src/blas.h
new file mode 100644
index 0000000000000000000000000000000000000000..5229882c36bb173109e3b44cb575afea6ecc9a9a
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/blas.h
@@ -0,0 +1,487 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+
+#ifdef PADDLE_WITH_MKLML
+#include "paddle/fluid/platform/dynload/mklml.h"
+#endif
+
+#ifdef PADDLE_WITH_LIBXSMM
+#include <libxsmm.h>
+#endif
+
+#ifdef PADDLE_USE_OPENBLAS
+#include <cblas.h>
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/**
+ * Matrix Descriptor of a memory buffer.
+ *
+ * It is used for Blas::MatMul. MatMul operator can be batched.
+ * if Mat A is [BatchSize, H, W] and Mat B is [BatchSize, H, W], it will be a
+ * `batch_size` times of GEMM. The batched GEMM could be faster based on the
+ * implementation of the blas library. The batch size could be zero. If any
+ * matrix of `matmul` has a batch size, there will be a batched GEMM, too. e.g.,
+ * if Mat A is [BatchSize, H1, W2] and Mat B is [H2, W2], the result matrix will
+ * be [BatchSize, H1, W2].
+ *
+ * The boolean flag, `trans`, describes whether the memory holds the transpose
+ * of the matrix. If trans is true, the last two dims of the matrix are
+ * transposed. The memory layout is [Width, Height] or [BatchSize, Width, Height].
+ *
+ * The MatDescriptor is not only the dimension or shape of a matrix, it also
+ * contains the layout, stride of matrix. It is clearer to have a structure than
+ * reuse `DDim`.
+ */
+struct MatDescriptor {
+ int64_t height_;
+ int64_t width_;
+ int64_t stride_{0};
+ int64_t batch_size_{0};
+ bool trans_;
+};
+
+/**
+ * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose
+ * flag
+ *
+ * @param tensor_dim: The dimension of the tensor. The rank of this dimension
+ * must be larger than 1.
+ *
+ * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first
+ * dimension(column length) will be the product of tensor's first `num_col_dims`
+ * dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the
+ * batch_size of descriptor.
+ *
+ * @param trans: True if the matrix is transposed.
+ */
+extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
+ int num_flatten_cols,
+ bool trans);
+
+template <typename DeviceContext>
+class Blas {
+public:
+ explicit Blas(const DeviceContext& context) : context_(context) {}
+
+  template <typename T>
+ void GEMM(CBLAS_TRANSPOSE transA,
+ CBLAS_TRANSPOSE transB,
+ int M,
+ int N,
+ int K,
+ T alpha,
+ const T* A,
+ const T* B,
+ T beta,
+ T* C) const;
+
+  template <typename T>
+ void GEMM(bool transA,
+ bool transB,
+ int M,
+ int N,
+ int K,
+ T alpha,
+ const T* A,
+ int lda,
+ const T* B,
+ int ldb,
+ T beta,
+ T* C,
+ int ldc) const;
+
+  template <typename T>
+ void GEMM(CBLAS_TRANSPOSE transA,
+ CBLAS_TRANSPOSE transB,
+ int M,
+ int N,
+ int K,
+ T alpha,
+ const T* A,
+ int lda,
+ const T* B,
+ int ldb,
+ T beta,
+ T* C,
+ int ldc) const;
+
+#ifdef PADDLE_WITH_MKLML
+  template <typename T>
+ T* GEMM_ALLOC(const CBLAS_IDENTIFIER id,
+ const int M,
+ const int N,
+ const int K) const;
+
+  template <typename T>
+ void GEMM_PACK(const CBLAS_IDENTIFIER id,
+ const CBLAS_TRANSPOSE trans,
+ int M,
+ int N,
+ int K,
+ const T alpha,
+ const T* src,
+ const int ld,
+ T* dst) const;
+
+  template <typename T>
+ void GEMM_COMPUTE(int transA,
+ int transB,
+ int M,
+ int N,
+ int K,
+ const T* A,
+ const int lda,
+ const T* B,
+ const int ldb,
+ T beta,
+ T* C,
+ const int ldc) const;
+
+  template <typename T>
+ void GEMM_FREE(T* data) const;
+
+  template <typename T>
+ void CSRMM(const char* transa,
+ const int* m,
+ const int* n,
+ const int* k,
+ const T* alpha,
+ const char* matdescra,
+ const T* val,
+ const int* indx,
+ const int* pntrb,
+ const int* pntre,
+ const T* b,
+ const int* ldb,
+ const T* beta,
+ T* c,
+ const int* ldc) const;
+
+#if !defined(PADDLE_WITH_CUDA)
+  template <typename T>
+ void MatMulWithHead(const framework::Tensor& mat_a,
+ const MatDescriptor& dim_a,
+ const framework::Tensor& mat_b,
+ const MatDescriptor& dim_b,
+ T alpha,
+ int head_number,
+ framework::Tensor* mat_out,
+ T beta,
+ bool mat_y_split_vertical) const;
+#endif
+#endif
+
+  template <typename T>
+ void MatMul(const int M,
+ const int N,
+ const int K,
+ const T* A,
+ const T* B,
+ T* C) const;
+
+  template <typename T>
+ void MatMul(const framework::Tensor& mat_a,
+ bool trans_a,
+ const framework::Tensor& mat_b,
+ bool trans_b,
+ T alpha,
+ framework::Tensor* mat_out,
+ T beta) const;
+
+  template <typename T>
+ void MatMul(const framework::Tensor& mat_a,
+ bool trans_a,
+ const framework::Tensor& mat_b,
+ bool trans_b,
+ framework::Tensor* mat_out) const {
+ MatMul(mat_a,
+ trans_a,
+ mat_b,
+ trans_b,
+           static_cast<T>(1.0),
+           mat_out,
+           static_cast<T>(0.0));
+ }
+
+  template <typename T>
+ void MatMul(const framework::Tensor& mat_a,
+ const framework::Tensor& mat_b,
+ framework::Tensor* mat_out) const {
+    this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
+ }
+
+  template <typename T>
+ void AXPY(int n, T alpha, const T* x, T* y) const;
+
+  template <typename T>
+ void VADD(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+ void VSUB(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+ void VMUL(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+ void VDIV(int n, const T* x, const T* y, T* z) const;
+
+  template <typename T>
+ void VCOPY(int n, const T* x, T* y) const;
+
+  template <typename T>
+ void VEXP(int n, const T* x, T* y) const;
+
+  template <typename T>
+ void VSQUARE(int n, const T* x, T* y) const;
+
+  template <typename T>
+ void VPOW(int n, const T* x, T alpha, T* y) const;
+
+  template <typename T>
+ void GEMV(bool trans_a,
+ int M,
+ int N,
+ T alpha,
+ const T* A,
+ const T* B,
+ T beta,
+ T* C) const;
+
+  template <typename T>
+ T DOT(int n, const T* x, const T* y) const;
+
+  template <typename T>
+ void SCAL(int n, const T a, T* x) const;
+
+  template <typename T>
+ T ASUM(int n, T* x, int inc) const;
+
+  template <typename T>
+ void BatchedGEMM(CBLAS_TRANSPOSE transA,
+ CBLAS_TRANSPOSE transB,
+ int M,
+ int N,
+ int K,
+ T alpha,
+ const T* A,
+ const T* B,
+ T beta,
+ T* C,
+ int batchCount,
+ int64_t strideA,
+ int64_t strideB) const;
+
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+  template <typename T>
+ void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA,
+ CBLAS_TRANSPOSE transB,
+ int W1,
+ int H1,
+ int W2,
+ int H2,
+ T alpha,
+ const T* A,
+ const T* B,
+ T beta,
+ T* C,
+ int batchCount,
+ int64_t strideA,
+ int64_t strideB,
+ int64_t head_number,
+ bool split_b_vertical) const;
+#endif
+
+  template <typename T>
+ void MatMul(const framework::Tensor& mat_a,
+ const MatDescriptor& dim_a,
+ const framework::Tensor& mat_b,
+ const MatDescriptor& dim_b,
+ T alpha,
+ framework::Tensor* mat_out,
+ T beta) const;
+
+  template <typename T>
+ void VINV(int n, const T* a, T* y) const;
+
+  template <typename T>
+ void VMERF(int n, const T* a, T* y, int64_t mode) const;
+
+private:
+ const DeviceContext& context_;
+};
+
+template <typename DeviceContext, typename T>
+class BlasT : private Blas<DeviceContext> {
+public:
+  using Blas<DeviceContext>::Blas;
+
+  template <typename... ARGS>
+  void GEMM(ARGS... args) const {
+    Base()->template GEMM<T>(args...);
+  }
+
+#ifdef PADDLE_WITH_MKLML
+  template <typename... ARGS>
+  T* GEMM_ALLOC(ARGS... args) const {
+    return Base()->template GEMM_ALLOC<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void GEMM_PACK(ARGS... args) const {
+    Base()->template GEMM_PACK<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void GEMM_COMPUTE(ARGS... args) const {
+    Base()->template GEMM_COMPUTE<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void GEMM_FREE(ARGS... args) const {
+    Base()->template GEMM_FREE<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void CSRMM(ARGS... args) const {
+    Base()->template CSRMM<T>(args...);
+  }
+
+#if !defined(PADDLE_WITH_CUDA)
+  template <typename... ARGS>
+  void MatMulWithHead(ARGS... args) const {
+    Base()->template MatMulWithHead<T>(args...);
+  }
+#endif
+#endif
+
+  template <typename... ARGS>
+  void MatMul(ARGS... args) const {
+    Base()->template MatMul<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void AXPY(ARGS... args) const {
+    Base()->template AXPY<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VADD(ARGS... args) const {
+    Base()->template VADD<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VSUB(ARGS... args) const {
+    Base()->template VSUB<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VMUL(ARGS... args) const {
+    Base()->template VMUL<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VDIV(ARGS... args) const {
+    Base()->template VDIV<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VCOPY(ARGS... args) const {
+    Base()->template VCOPY<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VEXP(ARGS... args) const {
+    Base()->template VEXP<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VSQUARE(ARGS... args) const {
+    Base()->template VSQUARE<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VPOW(ARGS... args) const {
+    Base()->template VPOW<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void GEMV(ARGS... args) const {
+    Base()->template GEMV<T>(args...);
+  }
+
+  template <typename... ARGS>
+  T DOT(ARGS... args) const {
+    return Base()->template DOT<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void SCAL(ARGS... args) const {
+    Base()->template SCAL<T>(args...);
+  }
+
+  template <typename... ARGS>
+  T ASUM(ARGS... args) const {
+    return Base()->template ASUM<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void BatchedGEMM(ARGS... args) const {
+    Base()->template BatchedGEMM<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VINV(ARGS... args) const {
+    Base()->template VINV<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VMERF(ARGS... args) const {
+    Base()->template VMERF<T>(args...);
+  }
+
+private:
+  const Blas<DeviceContext>* Base() const {
+    return static_cast<const Blas<DeviceContext>*>(this);
+  }
+};
+
+template <typename DeviceContext, typename T>
+inline BlasT<DeviceContext, T> GetBlas(
+    const framework::ExecutionContext& exe_ctx) {
+  return BlasT<DeviceContext, T>(
+      exe_ctx.template device_context<DeviceContext>());
+}
+
+template <typename DeviceContext, typename T>
+inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
+  return BlasT<DeviceContext, T>(dev_ctx);
+}
+
+} // namespace math
+} // namespace operators
+} // namespace paddle
+
+#include "paddle/fluid/operators/math/blas_impl.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/operators/math/blas_impl.cu.h"
+#endif
diff --git a/PaddleCV/rrpn/models/ext_op/src/concat_and_split.cc b/PaddleCV/rrpn/models/ext_op/src/concat_and_split.cc
new file mode 100644
index 0000000000000000000000000000000000000000..20bf99637e22dd3a833d5f922d39f8b9d9dd4a87
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/concat_and_split.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2019 paddlepaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "concat_and_split.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension must be the same, except the axis dimension.
+ */
+template <typename T>
+class ConcatFunctor<platform::CPUDeviceContext, T> {
+public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::vector<framework::Tensor>& input,
+                  int axis,
+                  framework::Tensor* output) {
+ // TODO(zcd): Add input data validity checking
+ int num = input.size();
+
+ int rows = 1;
+ auto dim_0 = input[0].dims();
+ for (int i = 0; i < axis; ++i) {
+ rows *= dim_0[i];
+ }
+ int out_rows = rows, out_cols = 0;
+
+ std::vector input_cols(input.size());
+ for (int i = 0; i < num; ++i) {
+ int t_cols = input[i].numel() / rows;
+ out_cols += t_cols;
+ input_cols[i] = t_cols;
+ }
+    auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+
+ // computation
+    auto output_data = output->data<T>();
+ int col_idx = 0;
+ for (int j = 0; j < num; ++j) {
+ int col_len = input_cols[j];
+      auto input_data = input[j].data<T>();
+ for (int k = 0; k < out_rows; ++k) {
+ memory::Copy(cpu_place,
+ output_data + k * out_cols + col_idx,
+ cpu_place,
+ input_data + k * col_len,
+ sizeof(T) * col_len);
+ }
+ col_idx += col_len;
+ }
+ }
+};
+
+#define DEFINE_FUNCTOR(type) \
+  template class ConcatFunctor<platform::CPUDeviceContext, type>;
+
+FOR_ALL_TYPES(DEFINE_FUNCTOR);
+
+} // namespace math
+} // namespace operators
+} // namespace paddle
diff --git a/PaddleCV/rrpn/models/ext_op/src/concat_and_split.h b/PaddleCV/rrpn/models/ext_op/src/concat_and_split.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5947597f6bfcee10973aa41deea45053955e3b1
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/concat_and_split.h
@@ -0,0 +1,59 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * \brief Concatenate the input tensors along the dimension axis.
+ * TODO(zcd): maybe it needs to be more detailed.
+ * Examples:
+ * Input[0] = [[1,2],[3,4]]
+ * Input[1] = [[5,6]]
+ * axis = 0
+ *
+ * Output = [[1,2],
+ * [3,4],
+ * [5,6]]
+ */
+template <typename DeviceContext, typename T>
+class ConcatFunctor {
+public:
+  void operator()(const DeviceContext& context,
+                  const std::vector<framework::Tensor>& input,
+                  int axis,
+                  framework::Tensor* output);
+};
+
+
+} // namespace math
+} // namespace operators
+} // namespace paddle
+
+#define FOR_ALL_TYPES(macro) \
+ macro(int); \
+ macro(float); \
+ macro(double); \
+ macro(bool); \
+ macro(int64_t); \
+ macro(int16_t); \
+ macro(uint8_t); \
+ macro(int8_t); \
+ macro(::paddle::platform::float16)
diff --git a/PaddleCV/rrpn/models/ext_op/src/gather.cu.h b/PaddleCV/rrpn/models/ext_op/src/gather.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..9e6b76b37c4ccae3e3001589a20cf9db4c76fe4a
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/gather.cu.h
@@ -0,0 +1,125 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/dim.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using platform::DeviceContext;
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+template <typename T, typename IndexT = int>
+__global__ void GatherCUDAKernel(const T* params,
+ const IndexT* indices,
+ T* output,
+ size_t index_size,
+ size_t slice_size) {
+ CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
+ int indices_i = i / slice_size;
+ int slice_i = i - indices_i * slice_size; // offset inside the slice
+ IndexT gather_i = indices[indices_i];
+ IndexT params_i = gather_i * slice_size + slice_i;
+ *(output + i) = *(params + params_i);
+ }
+}
+
+template <typename T, typename IndexT = int>
+__global__ void GatherNdCUDAKernel(const T* input,
+ const int* input_dims,
+ const IndexT* indices,
+ T* output,
+ size_t remain_size,
+ size_t slice_size,
+ size_t end_size) {
+ CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) {
+ int indices_i = i / slice_size;
+ int slice_i = i - indices_i * slice_size; // offset inside the slice
+ IndexT gather_i = 0;
+ int64_t temp = slice_size;
+ for (int64_t j = end_size - 1; j >= 0; --j) {
+ auto index_value = indices[indices_i * end_size + j];
+ assert(index_value >= 0 && index_value < input_dims[j]);
+ gather_i += (index_value * temp);
+ temp *= input_dims[j];
+ }
+ IndexT input_i = gather_i + slice_i;
+ *(output + i) = *(input + input_i);
+ }
+}
+
+/**
+ * A thin wrapper on gpu tensor
+ * Return a new tensor from source tensor, gathered according to index
+ * input[src]: type-T source Tensor
+ * input[index]: type-IndexT index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T, typename IndexT = int>
+void GPUGather(const platform::DeviceContext& ctx,
+ const Tensor& src,
+ const Tensor& index,
+ Tensor* output) {
+ // check index of shape 1-D
+ if (index.dims().size() == 1) {
+ PADDLE_ENFORCE_GT(index.dims()[0],
+ 0,
+ "The index of gather_op should not be empty when the "
+ "index's rank is 1.");
+ } else if (index.dims().size() == 2) {
+ PADDLE_ENFORCE_EQ(index.dims()[1],
+ 1,
+ " If the index's rank of gather_op is 2, the second "
+ "dimension should be 1.");
+ }
+
+ int index_size = index.dims()[0];
+
+ auto src_dims = src.dims();
+ framework::DDim output_dims(src_dims);
+ output_dims[0] = index_size;
+
+ // slice size
+ int slice_size = 1;
+ for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const T* p_src = src.data<T>();
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+
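+  // Launch one thread per copied element, using 512-thread blocks.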
+ int block = 512;
+ int n = slice_size * index_size;
+ int grid = (n + block - 1) / block;
+
+  GatherCUDAKernel<T, IndexT><<<
+      grid,
+      block,
+      0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
+}
+
+} // namespace operators
+} // namespace paddle
diff --git a/PaddleCV/rrpn/models/ext_op/src/gather.h b/PaddleCV/rrpn/models/ext_op/src/gather.h
new file mode 100644
index 0000000000000000000000000000000000000000..a2ee07427724a7850b009402a8a6b1255d2e8a79
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/gather.h
@@ -0,0 +1,74 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <memory.h>
+#include <cstring>
+
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+/**
+ * A thin wrapper for gathering on cpu tensor
+ * Return a new tensor from source tensor, gathered according to index
+ * input[src]: type-T source Tensor
+ * input[index]: type-IndexT index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T, typename IndexT = int>
+void CPUGather(const platform::DeviceContext& ctx,
+ const Tensor& src,
+ const Tensor& index,
+ Tensor* output) {
+ PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
+ // check index of shape 1-D
+ if (index.dims().size() == 2) {
+ PADDLE_ENFORCE_EQ(index.dims()[1],
+ 1,
+ "index.dims()[1] should be 1 when index.dims().size() == "
+ "2 in gather_op.");
+ } else {
+ PADDLE_ENFORCE_EQ(index.dims().size(),
+ 1,
+ "index.dims().size() should be 1 or 2 in gather_op.");
+ }
+ int64_t index_size = index.dims()[0];
+
+ auto src_dims = src.dims();
+
+  const T* p_src = src.data<T>();
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+
+ // slice size
+ int slice_size = 1;
+ for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+ const size_t slice_bytes = slice_size * sizeof(T);
+
+ for (int64_t i = 0; i < index_size; ++i) {
+ IndexT index_ = p_index[i];
+ memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
+ }
+}
+
+} // namespace operators
+} // namespace paddle
diff --git a/PaddleCV/rrpn/models/ext_op/src/make.sh b/PaddleCV/rrpn/models/ext_op/src/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..96810820da26e00c7aa23a6a103f6dc3401ce655
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/make.sh
@@ -0,0 +1,73 @@
+include_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_include())' )
+lib_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_lib())' )
+
+echo $include_dir
+echo $lib_dir
+
+CUDA=$1
+CUDNN=$2
+NCCL=$3
+
+if [ ! -d "$CUDA" ]; then
+echo "Usage: sh make.sh \$CUDA_PATH \$CUDNN_PATH \$NCCL_PATH"
+exit
+fi
+
+if [ ! -d "$CUDNN" ]; then
+echo "Usage: sh make.sh \${CUDA_PATH} \${CUDNN_PATH} \${NCCL_PATH}"
+exit
+fi
+
+if [ ! -d "$NCCL" ]; then
+echo "Usage: sh make.sh \${CUDA_PATH} \${CUDNN_PATH} \${NCCL_PATH}"
+exit
+fi
+
+git clone https://github.com/NVlabs/cub.git
+
+nvcc rrpn_generate_proposals_op.cu -c -o rrpn_generate_proposals_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
+ -I ${include_dir} \
+ -I ${include_dir}/third_party \
+ -I ${CUDA}/include \
+ -I ${CUDNN}/include \
+ -I ${NCCL}/include \
+ -L ${lib_dir} -lpaddle_framework \
+ -L ${CUDA}/lib64 -lcudart
+
+
+nvcc rotated_anchor_generator_op.cu -c -o rotated_anchor_generator_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
+ -I ${include_dir} \
+ -I ${include_dir}/third_party \
+ -I ${CUDA}/include \
+ -I ${CUDNN}/include \
+ -I ${NCCL}/include \
+ -L ${lib_dir} -lpaddle_framework \
+ -L ${CUDA}/lib64 -lcudart
+
+nvcc rrpn_box_coder_op.cu -c -o rrpn_box_coder_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
+ -I ${include_dir} \
+ -I ${include_dir}/third_party \
+ -I ${CUDA}/include \
+ -I ${CUDNN}/include \
+ -I ${NCCL}/include \
+ -L ${lib_dir} -lpaddle_framework \
+ -L ${CUDA}/lib64 -lcudart
+
+nvcc rrpn_rotated_roi_align_op.cu -c -o rrpn_rotated_roi_align_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
+ -I ${include_dir} \
+ -I ${include_dir}/third_party \
+ -I ${CUDA}/include \
+ -I ${CUDNN}/include \
+ -I ${NCCL}/include \
+ -L ${lib_dir} -lpaddle_framework \
+ -L ${CUDA}/lib64 -lcudart
+
+
+g++ rotated_anchor_generator_op.cc concat_and_split.cc rrpn_generate_proposal_labels_op.cc rrpn_generate_proposals_op.cc rrpn_target_assign_op.cc rrpn_box_coder_op.cc rrpn_rotated_roi_align_op.cc rrpn_rotated_roi_align_op.cu.o rrpn_box_coder_op.cu.o rotated_anchor_generator_op.cu.o rrpn_generate_proposals_op.cu.o -o rrpn_lib.so -shared -fPIC -std=c++11 -O3 -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO \
+ -I ${include_dir} \
+ -I ${include_dir}/third_party \
+ -I ${CUDA}/include \
+ -I ${CUDNN}/include \
+ -I ${NCCL}/include \
+ -L ${lib_dir} -lpaddle_framework \
+ -L ${CUDA}/lib64 -lcudart
diff --git a/PaddleCV/rrpn/models/ext_op/src/math_function.cc b/PaddleCV/rrpn/models/ext_op/src/math_function.cc
new file mode 100644
index 0000000000000000000000000000000000000000..24d5909eafa2293b1474f4c4e22f1ab177d610ac
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/math_function.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "math_function.h"
+
+#ifdef PADDLE_WITH_MKLML
+#include "paddle/fluid/platform/dynload/mklml.h"
+#endif
+
+#ifdef PADDLE_USE_OPENBLAS
+#include <cblas.h>
+#endif
+
+#include <vector>
+#include "math_function_impl.h"
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#define DEFINE_CPU_TRANS(RANK)                                           \
+  template struct Transpose<platform::CPUDeviceContext,                  \
+                            platform::float16,                           \
+                            RANK>;                                       \
+  template struct Transpose<platform::CPUDeviceContext, float, RANK>;    \
+  template struct Transpose<platform::CPUDeviceContext, double, RANK>;   \
+  template struct Transpose<platform::CPUDeviceContext, int, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;  \
+  template struct Transpose<platform::CPUDeviceContext, bool, RANK>;     \
+  template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>;  \
+  template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;  \
+  template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
+
+DEFINE_CPU_TRANS(1);
+DEFINE_CPU_TRANS(2);
+DEFINE_CPU_TRANS(3);
+DEFINE_CPU_TRANS(4);
+DEFINE_CPU_TRANS(5);
+DEFINE_CPU_TRANS(6);
+
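+// Permute tensor axes with Eigen's shuffle according to `axis`.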
+template <typename DeviceContext, typename T, int Rank>
+void Transpose<DeviceContext, T, Rank>::operator()(
+    const DeviceContext& context,
+    const framework::Tensor& in,
+    framework::Tensor* out,
+    const std::vector<int>& axis) {
+  Eigen::array<int, Rank> permute;
+  for (int i = 0; i < Rank; i++) {
+    permute[i] = axis[i];
+  }
+  auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
+  auto eigen_out = framework::EigenTensor<T, Rank>::From(*out);
+  auto* dev = context.eigen_device();
+  eigen_out.device(*dev) = eigen_in.shuffle(permute);
+}
+
+
+} // namespace math
+} // namespace operators
+} // namespace paddle
diff --git a/PaddleCV/rrpn/models/ext_op/src/math_function.h b/PaddleCV/rrpn/models/ext_op/src/math_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..b8043943ed4c3f0b26ca92a2d1ab8e091213673b
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/math_function.h
@@ -0,0 +1,43 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <cmath>
+#include <vector>
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+template <typename DeviceContext, typename T, int Rank>
+struct Transpose {
+ void operator()(const DeviceContext& context,
+ const framework::Tensor& in,
+ framework::Tensor* out,
+                  const std::vector<int>& axis);
+};
+
+void set_constant(const platform::DeviceContext& context,
+ framework::Tensor* tensor,
+ float value);
+
+} // namespace math
+} // namespace operators
+} // namespace paddle
diff --git a/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.cc b/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..854245aaa2ce80024d608fb76fda001b48a505ac
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.cc
@@ -0,0 +1,172 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "rotated_anchor_generator_op.h"
+
+namespace paddle {
+namespace operators {
+
+class RotatedAnchorGeneratorOp : public framework::OperatorWithKernel {
+public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(
+ ctx->HasInput("Input"),
+ "Input(Input) of RotatedAnchorGeneratorOp should not be null.");
+ PADDLE_ENFORCE(
+ ctx->HasOutput("Anchors"),
+ "Output(Anchors) of RotatedAnchorGeneratorOp should not be null.");
+ PADDLE_ENFORCE(
+ ctx->HasOutput("Variances"),
+ "Output(Variances) of RotatedAnchorGeneratorOp should not be null.");
+
+ auto input_dims = ctx->GetInputDim("Input");
+ PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
+
+    auto anchor_sizes = ctx->Attrs().Get<std::vector<float>>("anchor_sizes");
+    auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
+    auto angles = ctx->Attrs().Get<std::vector<float>>("angles");
+    auto stride = ctx->Attrs().Get<std::vector<float>>("stride");
+    auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
+
+ size_t num_anchors =
+ aspect_ratios.size() * anchor_sizes.size() * angles.size();
+
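+    // Output layout: [H, W, num_anchors, 5].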
+    std::vector<int64_t> dim_vec(4);
+ dim_vec[0] = input_dims[2];
+ dim_vec[1] = input_dims[3];
+ dim_vec[2] = num_anchors;
+ dim_vec[3] = 5;
+ ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec));
+ ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
+ }
+
+protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+        ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
+ }
+};
+
+class RotatedAnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+ void Make() override {
+    AddInput("Input",
+             "(Tensor, default Tensor<float>), "
+             "the input feature is a tensor with a rank of 4. "
+             "The layout is NCHW.");
+    AddOutput("Anchors",
+              "(Tensor, default Tensor<float>), the output is a "
+              "tensor with a rank of 4. The layout is [H, W, num_anchors, 5]. "
+              "H is the height of input, W is the width of input, num_anchors "
+              "is the box count of each position. "
+              "Each anchor is in (xctr, yctr, w, h, theta) format");
+    AddOutput("Variances",
+              "(Tensor, default Tensor<float>), the expanded variances for "
+              "normalizing bbox regression targets. The layout is [H, W, "
+              "num_anchors, 5]. "
+              "H is the height of input, W is the width of input, num_anchors "
+              "is the box count of each position. "
+              "Each variance is in (xctr, yctr, w, h, theta) format");
+
+    AddAttr<std::vector<float>>(
+        "anchor_sizes",
+        "(vector<float>) List of Rotated Region Proposal Network(RRPN) anchor "
+        "sizes "
+        "given in absolute pixels e.g. (64, 128, 256, 512). "
+        "For instance, an anchor size of 64 means the area of this anchor "
+        "equals 64**2.")
+        .AddCustomChecker([](const std::vector<float>& anchor_sizes) {
+ PADDLE_ENFORCE_GT(anchor_sizes.size(),
+ 0UL,
+ "Size of anchor_sizes must be at least 1.");
+ for (size_t i = 0; i < anchor_sizes.size(); ++i) {
+ PADDLE_ENFORCE_GT(
+ anchor_sizes[i], 0.0, "anchor_sizes[%d] must be positive.", i);
+ }
+ });
+    AddAttr<std::vector<float>>(
+        "aspect_ratios",
+        "(vector<float>) List of Rotated Region Proposal Network(RRPN) anchor "
+        "aspect "
+        "ratios, e.g. (0.5, 1, 2). "
+        "For instance, an aspect ratio of 0.5 means the height / width of "
+        "this anchor equals 0.5.");
+    AddAttr<std::vector<float>>(
+        "angles",
+        "(vector<float>) List of Rotated Region Proposal Network(RRPN) anchor "
+        "angles in degrees, "
+        "e.g. (-30.0, 0.0, 30.0, 60.0, 90.0, 120.0). "
+        "Each listed angle yields one anchor orientation at every position.");
+
+ AddAttr>("variances",
+ "(vector) List of variances to be used "
+ "in box regression deltas")
+ .AddCustomChecker([](const std::vector& variances) {
+ PADDLE_ENFORCE_EQ(
+ variances.size(), 5UL, "Must and only provide 5 variance.");
+ for (size_t i = 0; i < variances.size(); ++i) {
+ PADDLE_ENFORCE_GT(
+ variances[i], 0.0, "variance[%d] must be greater than 0.", i);
+ }
+ });
+
+ AddAttr>("stride",
+ "Anchors stride across width and height, "
+ "with a default of (16, 16)")
+ .SetDefault(std::vector(2, 16.0))
+ .AddCustomChecker([](const std::vector& stride) {
+ PADDLE_ENFORCE_EQ(
+ stride.size(),
+ 2UL,
+ "Must and only provide 2 stride for width and height.");
+ for (size_t i = 0; i < stride.size(); ++i) {
+ PADDLE_ENFORCE_GT(
+ stride[i], 0.0, "stride[%d] should be larger than 0.", i);
+ }
+ });
+
+ AddAttr("offset",
+ "(float) "
+ "Anchor center offset, with a default of 0.5")
+ .SetDefault(0.5);
+ AddComment(R"DOC(
+RotatedAnchorGenerator operator
+Generates anchors for the RRPN algorithm.
+Each position of the input produces N anchors, N =
+ size(anchor_sizes) * size(aspect_ratios) * size(angles).
+
+Please get more information from the following papers:
+https://arxiv.org/abs/1703.01086.
+)DOC");
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+ rotated_anchor_generator,
+ ops::RotatedAnchorGeneratorOp,
+ ops::RotatedAnchorGeneratorOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OP_CPU_KERNEL(rotated_anchor_generator,
+                       ops::RotatedAnchorGeneratorOpKernel<float>,
+                       ops::RotatedAnchorGeneratorOpKernel<double>);
diff --git a/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.cu b/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9c5250103326baa08588c3acacc7c1767da18332
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.cu
@@ -0,0 +1,153 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "rotated_anchor_generator_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void GenRAnchors(T* out,
+ const T* aspect_ratios,
+ const int ar_num,
+ const T* anchor_sizes,
+ const int as_num,
+ const T* angles,
+ const int aa_num,
+ const T* stride,
+ const int sd_num,
+ const int height,
+ const int width,
+ const T offset) {
+ int num_anchors = as_num * ar_num * aa_num;
+ int box_num = height * width * num_anchors;
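+ // Grid-stride loop: each thread generates one or more anchor boxes.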
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
+ i += blockDim.x * gridDim.x) {
+ int h_idx = i / (num_anchors * width);
+ int w_idx = (i / num_anchors) % width;
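+ // Map the feature-map cell (w_idx, h_idx) back to an anchor center on
+ // the input image, shifted by `offset` of a stride.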
+ T stride_width = stride[0];
+ T stride_height = stride[1];
+ T x_ctr = (w_idx * stride_width) + offset * stride_width - 1;
+ T y_ctr = (h_idx * stride_height) + offset * stride_height - 1;
+ T area, area_ratios;
+ T base_w, base_h;
+ T scale_w, scale_h;
+ T anchor_width, anchor_height;
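+ // Split the flat anchor index into (aspect_ratio, size, angle) indices;
+ // the angle index varies fastest.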
+ int anch_idx = i % num_anchors;
+ int ar_idx = anch_idx / (as_num * aa_num);
+ int as_idx = anch_idx / aa_num % as_num;
+ int aa_idx = anch_idx % aa_num;
+ T aspect_ratio = aspect_ratios[ar_idx];
+ T anchor_size = anchor_sizes[as_idx];
+ T angle = angles[aa_idx];
+ area = stride_width * stride_height;
+ area_ratios = area / aspect_ratio;
+ base_w = round(sqrt(area_ratios));
+ base_h = round(base_w * aspect_ratio);
+ scale_w = anchor_size / stride_width;
+ scale_h = anchor_size / stride_height;
+ anchor_width = scale_w * base_w;
+ anchor_height = scale_h * base_h;
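+ // Each rotated anchor is stored as 5 values: (x_ctr, y_ctr, w, h, angle).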
+ out[i * 5] = x_ctr;
+ out[i * 5 + 1] = y_ctr;
+ out[i * 5 + 2] = anchor_width;
+ out[i * 5 + 3] = anchor_height;
+ out[i * 5 + 4] = angle;
+ }
+}
+
+template <typename T>
+__global__ void SetVariance(T* out,
+ const T* var,
+ const int vnum,
+ const int num) {
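+ // Broadcast the vnum variance values cyclically over all num outputs.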
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+ i += blockDim.x * gridDim.x) {
+ out[i] = var[i % vnum];
+ }
+}
+
+template <typename T>
+class RotatedAnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
+public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ auto* input = ctx.Input("Input");
+ auto* anchors = ctx.Output("Anchors");
+ auto* vars = ctx.Output("Variances");
+
+ auto anchor_sizes = ctx.Attr>("anchor_sizes");
+ auto aspect_ratios = ctx.Attr>("aspect_ratios");
+ auto angles = ctx.Attr>("angles");
+ auto stride = ctx.Attr>("stride");
+ auto variances = ctx.Attr>("variances");
+
+ T offset = static_cast(ctx.Attr("offset"));
+
+ auto width = input->dims()[3];
+ auto height = input->dims()[2];
+
+ int num_anchors =
+ aspect_ratios.size() * anchor_sizes.size() * angles.size();
+
+ int box_num = width * height * num_anchors;
+
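+ // Launch configuration: one thread per anchor box.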
+ int block = 512;
+ int grid = (box_num + block - 1) / block;
+
+ auto stream =
+ ctx.template device_context<platform::CUDADeviceContext>().stream();
+
+ anchors->mutable_data<T>(ctx.GetPlace());
+ vars->mutable_data<T>(ctx.GetPlace());
+
+ framework::Tensor ar;
+ framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar);
+
+ framework::Tensor as;
+ framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as);
+
+ framework::Tensor aa;
+ framework::TensorFromVector(angles, ctx.device_context(), &aa);
+
+ framework::Tensor sd;
+ framework::TensorFromVector(stride, ctx.device_context(), &sd);
+
+ GenRAnchors<T><<<grid, block, 0, stream>>>(anchors->data<T>(),
+ ar.data<T>(),
+ aspect_ratios.size(),
+ as.data<T>(),
+ anchor_sizes.size(),
+ aa.data<T>(),
+ angles.size(),
+ sd.data<T>(),
+ stride.size(),
+ height,
+ width,
+ offset);
+
+ framework::Tensor v;
+ framework::TensorFromVector(variances, ctx.device_context(), &v);
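+ // Each anchor carries 5 variance values, so cover box_num * 5 elements.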
+ grid = (box_num * 5 + block - 1) / block;
+ SetVariance<T><<<grid, block, 0, stream>>>(
+ vars->data<T>(), v.data<T>(), variances.size(), box_num * 5);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(rotated_anchor_generator,
+ ops::RotatedAnchorGeneratorOpCUDAKernel<float>,
+ ops::RotatedAnchorGeneratorOpCUDAKernel<double>);
diff --git a/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.h b/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..81239d1f97303ca946efecb5234c4c09d31ae2c0
--- /dev/null
+++ b/PaddleCV/rrpn/models/ext_op/src/rotated_anchor_generator_op.h
@@ -0,0 +1,111 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+//#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class RotatedAnchorGeneratorOpKernel : public framework::OpKernel<T> {
+public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ auto* input = ctx.Input("Input");
+ auto* anchors = ctx.Output("Anchors");
+ auto* vars = ctx.Output("Variances");
+
+ auto anchor_sizes = ctx.Attr>("anchor_sizes");
+ auto aspect_ratios = ctx.Attr>("aspect_ratios");
+ auto angles = ctx.Attr>("angles");
+ auto stride = ctx.Attr>("stride");
+ auto variances = ctx.Attr>("variances");
+
+ T offset = static_cast(ctx.Attr("offset"));
+
+ auto feature_width = input->dims()[3];
+ auto feature_height = input->dims()[2];
+
+ T stride_width, stride_height;
+ stride_width = stride[0];
+ stride_height = stride[1];
+
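+ // Number of anchors generated at every feature-map position.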
+ int num_anchors =
+ aspect_ratios.size() * anchor_sizes.size() * angles.size();
+
+ anchors->mutable_data<T>(ctx.GetPlace());
+ vars->mutable_data<T>(ctx.GetPlace());
+
+ auto e_anchors = framework::EigenTensor<T, 4>