diff --git a/configs/det/det_r50_drrg_ctw.yml b/configs/det/det_r50_drrg_ctw.yml
new file mode 100755
index 0000000000000000000000000000000000000000..f67c926f3a8294a41df0751357061c69a895549e
--- /dev/null
+++ b/configs/det/det_r50_drrg_ctw.yml
@@ -0,0 +1,133 @@
+Global:
+ use_gpu: true
+ epoch_num: 1200
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/det_r50_drrg_ctw/
+ save_epoch_step: 100
+ # evaluation is run every 1260 iterations
+ eval_batch_step: [37800, 1260]
+ cal_metric_during_train: False
+ pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_en/img_10.jpg
+ save_res_path: ./output/det_drrg/predicts_drrg.txt
+
+
+Architecture:
+ model_type: det
+ algorithm: DRRG
+ Transform:
+ Backbone:
+ name: ResNet_vd
+ layers: 50
+ Neck:
+ name: FPN_UNet
+ in_channels: [256, 512, 1024, 2048]
+ out_channels: 32
+ Head:
+ name: DRRGHead
+ in_channels: 32
+ text_region_thr: 0.3
+ center_region_thr: 0.4
+Loss:
+ name: DRRGLoss
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ lr:
+ name: DecayLearningRate
+ learning_rate: 0.028
+ epochs: 1200
+ factor: 0.9
+ end_lr: 0.0000001
+ weight_decay: 0.0001
+
+PostProcess:
+ name: DRRGPostprocess
+ link_thr: 0.8
+
+Metric:
+ name: DetFCEMetric
+ main_indicator: hmean
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ctw1500/imgs/
+ label_file_list:
+ - ./train_data/ctw1500/imgs/training.txt
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ ignore_orientation: True
+ - DetLabelEncode: # Class handling label
+ - ColorJitter:
+ brightness: 0.12549019607843137
+ saturation: 0.5
+ - RandomScaling:
+ - RandomCropFlip:
+ crop_ratio: 0.5
+ - RandomCropPolyInstances:
+ crop_ratio: 0.8
+ min_side_ratio: 0.3
+ - RandomRotatePolyInstances:
+ rotate_ratio: 0.5
+ max_angle: 60
+ pad_with_fixed_color: False
+ - SquareResizePad:
+ target_size: 800
+ pad_ratio: 0.6
+ - IaaAugment:
+ augmenter_args:
+ - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+ - DRRGTargets:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
+ 'gt_top_height_map', 'gt_bot_height_map', 'gt_sin_map',
+ 'gt_cos_map', 'gt_comp_attribs'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ drop_last: False
+ batch_size_per_card: 4
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ctw1500/imgs/
+ label_file_list:
+ - ./train_data/ctw1500/imgs/test.txt
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ ignore_orientation: True
+ - DetLabelEncode: # Class handling label
+ - DetResizeForTest:
+ limit_type: 'min'
+ limit_side_len: 640
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: 'hwc'
+ - Pad:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1 # must be 1
+ num_workers: 2
\ No newline at end of file
diff --git a/doc/doc_ch/algorithm_det_drrg.md b/doc/doc_ch/algorithm_det_drrg.md
new file mode 100644
index 0000000000000000000000000000000000000000..d89a16ae68b7024238a3982a342ef39764da9d16
--- /dev/null
+++ b/doc/doc_ch/algorithm_det_drrg.md
@@ -0,0 +1,78 @@
+# DRRG
+
+- [1. 算法简介](#1-算法简介)
+- [2. 环境配置](#2-环境配置)
+- [3. 模型训练、评估、预测](#3-模型训练评估预测)
+- [4. 推理部署](#4-推理部署)
+ - [4.1 Python推理](#41-python推理)
+ - [4.2 C++推理](#42-c推理)
+ - [4.3 Serving服务化部署](#43-serving服务化部署)
+ - [4.4 更多推理部署](#44-更多推理部署)
+- [5. FAQ](#5-faq)
+- [引用](#引用)
+
+
+## 1. 算法简介
+
+论文信息:
+> [Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection](https://arxiv.org/abs/2003.07493)
+> Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng
+> CVPR, 2020
+
+在CTW1500文本检测公开数据集上,算法复现效果如下:
+
+| 模型 |骨干网络|配置文件|precision|recall|Hmean|下载链接|
+|-----| --- | --- | --- | --- | --- | --- |
+| DRRG | ResNet50_vd | [configs/det/det_r50_drrg_ctw.yml](../../configs/det/det_r50_drrg_ctw.yml)| 89.92%|80.91%|85.18%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar)|
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+上述DRRG模型使用CTW1500文本检测公开数据集训练得到,数据集下载可参考 [ocr_datasets](./dataset/ocr_datasets.md)。
+
+数据下载完成后,请参考[文本检测训练教程](./detection.md)进行训练。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。
+
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+
+由于模型前向运行时需要多次转换为Numpy数据进行运算,因此DRRG的动态图转静态图暂未支持。
+
+
+### 4.2 C++推理
+
+暂未支持
+
+
+### 4.3 Serving服务化部署
+
+暂未支持
+
+
+### 4.4 更多推理部署
+
+暂未支持
+
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@inproceedings{zhang2020deep,
+ title={Deep relational reasoning graph network for arbitrary shape text detection},
+ author={Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={9699--9708},
+ year={2020}
+}
+```
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index e9bcc275d1a7157628188c337a312a49408d207b..2dfece9da6b3bad7ca9b9ff8dd95cbb20a9f2a87 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -29,6 +29,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
- [x] [SAST](./algorithm_det_sast.md)
- [x] [PSENet](./algorithm_det_psenet.md)
- [x] [FCENet](./algorithm_det_fcenet.md)
+- [x] [DRRG](./algorithm_det_drrg.md)
在ICDAR2015文本检测公开数据集上,算法效果如下:
@@ -54,6 +55,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|模型|骨干网络|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- |
|FCE|ResNet50_dcn|88.39%|82.18%|85.27%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar)|
+|DRRG|ResNet50_vd|89.92%|80.91%|85.18%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar)|
**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:
* [百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi)
diff --git a/doc/doc_en/algorithm_det_drrg_en.md b/doc/doc_en/algorithm_det_drrg_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..2bb7b5703dab89526345e3dcbbb55d6c90ed1c0c
--- /dev/null
+++ b/doc/doc_en/algorithm_det_drrg_en.md
@@ -0,0 +1,79 @@
+# DRRG
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection](https://arxiv.org/abs/2003.07493)
+> Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng
+> CVPR, 2020
+
+On the CTW1500 dataset, the text detection result is as follows:
+
+|Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
+| --- | --- | --- | --- | --- | --- | --- |
+| DRRG | ResNet50_vd | [configs/det/det_r50_drrg_ctw.yml](../../configs/det/det_r50_drrg_ctw.yml)| 89.92%|80.91%|85.18%|[trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar)|
+
+
+## 2. Environment
+Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md).
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+The above DRRG model is trained using the CTW1500 text detection public dataset. For the download of the dataset, please refer to [ocr_datasets](./dataset/ocr_datasets_en.md).
+
+After the data download is complete, please refer to [Text Detection Training Tutorial](./detection_en.md) for training. PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different detection models.
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+
+Since the model needs to be converted to Numpy data for many times in the forward, DRRG dynamic graph to static graph is not supported.
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@inproceedings{zhang2020deep,
+ title={Deep relational reasoning graph network for arbitrary shape text detection},
+ author={Zhang, Shi-Xue and Zhu, Xiaobin and Hou, Jie-Bo and Liu, Chang and Yang, Chun and Wang, Hongfa and Yin, Xu-Cheng},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={9699--9708},
+ year={2020}
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 90449e1729fcff898f27641d3f777c8f002f6a97..bbab2638418c32f871d73356f6fff17f9b98f685 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -27,6 +27,7 @@ Supported text detection algorithms (Click the link to get the tutorial):
- [x] [SAST](./algorithm_det_sast_en.md)
- [x] [PSENet](./algorithm_det_psenet_en.md)
- [x] [FCENet](./algorithm_det_fcenet_en.md)
+- [x] [DRRG](./algorithm_det_drrg_en.md)
On the ICDAR2015 dataset, the text detection result is as follows:
@@ -52,6 +53,7 @@ On CTW1500 dataset, the text detection result is as follows:
|Model|Backbone|Precision|Recall|Hmean| Download link|
| --- | --- | --- | --- | --- |---|
|FCE|ResNet50_dcn|88.39%|82.18%|85.27%| [trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar) |
+|DRRG|ResNet50_vd|89.92%|80.91%|85.18%|[trained model](https://paddleocr.bj.bcebos.com/contribution/det_r50_drrg_ctw.tar)|
**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from:
* [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index db0a489d5ff4cce5a51d8a2347a595efec61a1db..93d97446d44070b9c10064fbe10b0b5e05628a6a 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -45,6 +45,7 @@ from .vqa import *
from .fce_aug import *
from .fce_targets import FCENetTargets
from .ct_process import *
+from .drrg_targets import DRRGTargets
def transform(data, ops=None):
diff --git a/ppocr/data/imaug/drrg_targets.py b/ppocr/data/imaug/drrg_targets.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56e878b837328ef2efde40b96b5571dffbb4791
--- /dev/null
+++ b/ppocr/data/imaug/drrg_targets.py
@@ -0,0 +1,696 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py
+"""
+
+import cv2
+import numpy as np
+from lanms import merge_quadrangle_n9 as la_nms
+from numpy.linalg import norm
+
+
+class DRRGTargets(object):
+ def __init__(self,
+ orientation_thr=2.0,
+ resample_step=8.0,
+ num_min_comps=9,
+ num_max_comps=600,
+ min_width=8.0,
+ max_width=24.0,
+ center_region_shrink_ratio=0.3,
+ comp_shrink_ratio=1.0,
+ comp_w_h_ratio=0.3,
+ text_comp_nms_thr=0.25,
+ min_rand_half_height=8.0,
+ max_rand_half_height=24.0,
+ jitter_level=0.2,
+ **kwargs):
+
+ super().__init__()
+ self.orientation_thr = orientation_thr
+ self.resample_step = resample_step
+ self.num_max_comps = num_max_comps
+ self.num_min_comps = num_min_comps
+ self.min_width = min_width
+ self.max_width = max_width
+ self.center_region_shrink_ratio = center_region_shrink_ratio
+ self.comp_shrink_ratio = comp_shrink_ratio
+ self.comp_w_h_ratio = comp_w_h_ratio
+ self.text_comp_nms_thr = text_comp_nms_thr
+ self.min_rand_half_height = min_rand_half_height
+ self.max_rand_half_height = max_rand_half_height
+ self.jitter_level = jitter_level
+ self.eps = 1e-8
+
+ def vector_angle(self, vec1, vec2):
+ if vec1.ndim > 1:
+ unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps).reshape((-1, 1))
+ else:
+ unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps)
+ if vec2.ndim > 1:
+ unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps).reshape((-1, 1))
+ else:
+ unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps)
+ return np.arccos(
+ np.clip(
+ np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+ def vector_slope(self, vec):
+ assert len(vec) == 2
+ return abs(vec[1] / (vec[0] + self.eps))
+
+ def vector_sin(self, vec):
+ assert len(vec) == 2
+ return vec[1] / (norm(vec) + self.eps)
+
+ def vector_cos(self, vec):
+ assert len(vec) == 2
+ return vec[0] / (norm(vec) + self.eps)
+
+ def find_head_tail(self, points, orientation_thr):
+
+ assert points.ndim == 2
+ assert points.shape[0] >= 4
+ assert points.shape[1] == 2
+ assert isinstance(orientation_thr, float)
+
+ if len(points) > 4:
+ pad_points = np.vstack([points, points[0]])
+ edge_vec = pad_points[1:] - pad_points[:-1]
+
+ theta_sum = []
+ adjacent_vec_theta = []
+ for i, edge_vec1 in enumerate(edge_vec):
+ adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+ adjacent_edge_vec = edge_vec[adjacent_ind]
+ temp_theta_sum = np.sum(
+ self.vector_angle(edge_vec1, adjacent_edge_vec))
+ temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+ adjacent_edge_vec[1])
+ theta_sum.append(temp_theta_sum)
+ adjacent_vec_theta.append(temp_adjacent_theta)
+ theta_sum_score = np.array(theta_sum) / np.pi
+ adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+ poly_center = np.mean(points, axis=0)
+ edge_dist = np.maximum(
+ norm(
+ pad_points[1:] - poly_center, axis=-1),
+ norm(
+ pad_points[:-1] - poly_center, axis=-1))
+ dist_score = edge_dist / (np.max(edge_dist) + self.eps)
+ position_score = np.zeros(len(edge_vec))
+ score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+ score += 0.35 * dist_score
+ if len(points) % 2 == 0:
+ position_score[(len(score) // 2 - 1)] += 1
+ position_score[-1] += 1
+ score += 0.1 * position_score
+ pad_score = np.concatenate([score, score])
+ score_matrix = np.zeros((len(score), len(score) - 3))
+ x = np.arange(len(score) - 3) / float(len(score) - 4)
+ gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+ (x - 0.5) / 0.5, 2.) / 2)
+ gaussian = gaussian / np.max(gaussian)
+ for i in range(len(score)):
+ score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
+ score) - 1)] * gaussian * 0.3
+
+ head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
+ score_matrix.shape)
+ tail_start = (head_start + tail_increment + 2) % len(points)
+ head_end = (head_start + 1) % len(points)
+ tail_end = (tail_start + 1) % len(points)
+
+ if head_end > tail_end:
+ head_start, tail_start = tail_start, head_start
+ head_end, tail_end = tail_end, head_end
+ head_inds = [head_start, head_end]
+ tail_inds = [tail_start, tail_end]
+ else:
+ if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+ points[3] - points[2]) < self.vector_slope(points[
+ 2] - points[1]) + self.vector_slope(points[0] - points[
+ 3]):
+ horizontal_edge_inds = [[0, 1], [2, 3]]
+ vertical_edge_inds = [[3, 0], [1, 2]]
+ else:
+ horizontal_edge_inds = [[3, 0], [1, 2]]
+ vertical_edge_inds = [[0, 1], [2, 3]]
+
+ vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+ vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+ 0]] - points[vertical_edge_inds[1][1]])
+ horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+ 0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+ horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
+ [1]])
+
+ if vertical_len_sum > horizontal_len_sum * orientation_thr:
+ head_inds = horizontal_edge_inds[0]
+ tail_inds = horizontal_edge_inds[1]
+ else:
+ head_inds = vertical_edge_inds[0]
+ tail_inds = vertical_edge_inds[1]
+
+ return head_inds, tail_inds
+
+ def reorder_poly_edge(self, points):
+
+ assert points.ndim == 2
+ assert points.shape[0] >= 4
+ assert points.shape[1] == 2
+
+ head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
+ head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+ pad_points = np.vstack([points, points])
+ if tail_inds[1] < 1:
+ tail_inds[1] = len(points)
+ sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+ sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+ sideline_mean_shift = np.mean(
+ sideline1, axis=0) - np.mean(
+ sideline2, axis=0)
+
+ if sideline_mean_shift[1] > 0:
+ top_sideline, bot_sideline = sideline2, sideline1
+ else:
+ top_sideline, bot_sideline = sideline1, sideline2
+
+ return head_edge, tail_edge, top_sideline, bot_sideline
+
+ def cal_curve_length(self, line):
+
+ assert line.ndim == 2
+ assert len(line) >= 2
+
+ edges_length = np.sqrt((line[1:, 0] - line[:-1, 0])**2 + (line[
+ 1:, 1] - line[:-1, 1])**2)
+ total_length = np.sum(edges_length)
+ return edges_length, total_length
+
+ def resample_line(self, line, n):
+
+ assert line.ndim == 2
+ assert line.shape[0] >= 2
+ assert line.shape[1] == 2
+ assert isinstance(n, int)
+ assert n > 2
+
+ edges_length, total_length = self.cal_curve_length(line)
+ t_org = np.insert(np.cumsum(edges_length), 0, 0)
+ unit_t = total_length / (n - 1)
+ t_equidistant = np.arange(1, n - 1, dtype=np.float32) * unit_t
+ edge_ind = 0
+ points = [line[0]]
+ for t in t_equidistant:
+ while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
+ edge_ind += 1
+ t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
+ weight = np.array(
+ [t_r - t, t - t_l], dtype=np.float32) / (t_r - t_l + self.eps)
+ p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
+ points.append(p_coords)
+ points.append(line[-1])
+ resampled_line = np.vstack(points)
+
+ return resampled_line
+
+ def resample_sidelines(self, sideline1, sideline2, resample_step):
+
+ assert sideline1.ndim == sideline2.ndim == 2
+ assert sideline1.shape[1] == sideline2.shape[1] == 2
+ assert sideline1.shape[0] >= 2
+ assert sideline2.shape[0] >= 2
+ assert isinstance(resample_step, float)
+
+ _, length1 = self.cal_curve_length(sideline1)
+ _, length2 = self.cal_curve_length(sideline2)
+
+ avg_length = (length1 + length2) / 2
+ resample_point_num = max(int(float(avg_length) / resample_step) + 1, 3)
+
+ resampled_line1 = self.resample_line(sideline1, resample_point_num)
+ resampled_line2 = self.resample_line(sideline2, resample_point_num)
+
+ return resampled_line1, resampled_line2
+
+ def dist_point2line(self, point, line):
+
+ assert isinstance(line, tuple)
+ point1, point2 = line
+ d = abs(np.cross(point2 - point1, point - point1)) / (
+ norm(point2 - point1) + 1e-8)
+ return d
+
+ def draw_center_region_maps(self, top_line, bot_line, center_line,
+ center_region_mask, top_height_map,
+ bot_height_map, sin_map, cos_map,
+ region_shrink_ratio):
+
+ assert top_line.shape == bot_line.shape == center_line.shape
+ assert (center_region_mask.shape == top_height_map.shape ==
+ bot_height_map.shape == sin_map.shape == cos_map.shape)
+ assert isinstance(region_shrink_ratio, float)
+
+ h, w = center_region_mask.shape
+ for i in range(0, len(center_line) - 1):
+
+ top_mid_point = (top_line[i] + top_line[i + 1]) / 2
+ bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
+
+ sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
+ cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
+
+ tl = center_line[i] + (top_line[i] - center_line[i]
+ ) * region_shrink_ratio
+ tr = center_line[i + 1] + (top_line[i + 1] - center_line[i + 1]
+ ) * region_shrink_ratio
+ br = center_line[i + 1] + (bot_line[i + 1] - center_line[i + 1]
+ ) * region_shrink_ratio
+ bl = center_line[i] + (bot_line[i] - center_line[i]
+ ) * region_shrink_ratio
+ current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
+
+ cv2.fillPoly(center_region_mask, [current_center_box], color=1)
+ cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
+ cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
+
+ current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
+ w - 1)
+ current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
+ h - 1)
+ min_coord = np.min(current_center_box, axis=0).astype(np.int32)
+ max_coord = np.max(current_center_box, axis=0).astype(np.int32)
+ current_center_box = current_center_box - min_coord
+ box_sz = (max_coord - min_coord + 1)
+
+ center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
+ cv2.fillPoly(center_box_mask, [current_center_box], color=1)
+
+ inds = np.argwhere(center_box_mask > 0)
+ inds = inds + (min_coord[1], min_coord[0])
+ inds_xy = np.fliplr(inds)
+ top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
+ inds_xy, (top_line[i], top_line[i + 1]))
+ bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
+ inds_xy, (bot_line[i], bot_line[i + 1]))
+
+ def generate_center_mask_attrib_maps(self, img_size, text_polys):
+
+ assert isinstance(img_size, tuple)
+
+ h, w = img_size
+
+ center_lines = []
+ center_region_mask = np.zeros((h, w), np.uint8)
+ top_height_map = np.zeros((h, w), dtype=np.float32)
+ bot_height_map = np.zeros((h, w), dtype=np.float32)
+ sin_map = np.zeros((h, w), dtype=np.float32)
+ cos_map = np.zeros((h, w), dtype=np.float32)
+
+ for poly in text_polys:
+ polygon_points = poly
+ _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
+ resampled_top_line, resampled_bot_line = self.resample_sidelines(
+ top_line, bot_line, self.resample_step)
+ resampled_bot_line = resampled_bot_line[::-1]
+ center_line = (resampled_top_line + resampled_bot_line) / 2
+
+ if self.vector_slope(center_line[-1] - center_line[0]) > 2:
+ if (center_line[-1] - center_line[0])[1] < 0:
+ center_line = center_line[::-1]
+ resampled_top_line = resampled_top_line[::-1]
+ resampled_bot_line = resampled_bot_line[::-1]
+ else:
+ if (center_line[-1] - center_line[0])[0] < 0:
+ center_line = center_line[::-1]
+ resampled_top_line = resampled_top_line[::-1]
+ resampled_bot_line = resampled_bot_line[::-1]
+
+ line_head_shrink_len = np.clip(
+ (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
+ self.min_width, self.max_width) / 2
+ line_tail_shrink_len = np.clip(
+ (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
+ self.min_width, self.max_width) / 2
+ num_head_shrink = int(line_head_shrink_len // self.resample_step)
+ num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
+ if len(center_line) > num_head_shrink + num_tail_shrink + 2:
+ center_line = center_line[num_head_shrink:len(center_line) -
+ num_tail_shrink]
+ resampled_top_line = resampled_top_line[num_head_shrink:len(
+ resampled_top_line) - num_tail_shrink]
+ resampled_bot_line = resampled_bot_line[num_head_shrink:len(
+ resampled_bot_line) - num_tail_shrink]
+ center_lines.append(center_line.astype(np.int32))
+
+ self.draw_center_region_maps(
+ resampled_top_line, resampled_bot_line, center_line,
+ center_region_mask, top_height_map, bot_height_map, sin_map,
+ cos_map, self.center_region_shrink_ratio)
+
+ return (center_lines, center_region_mask, top_height_map,
+ bot_height_map, sin_map, cos_map)
+
+ def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
+
+ assert isinstance(num_rand_comps, int)
+ assert num_rand_comps > 0
+ assert center_sample_mask.ndim == 2
+
+ h, w = center_sample_mask.shape
+
+ max_rand_half_height = self.max_rand_half_height
+ min_rand_half_height = self.min_rand_half_height
+ max_rand_height = max_rand_half_height * 2
+ max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
+ self.min_width, self.max_width)
+ margin = int(
+ np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
+
+ if 2 * margin + 1 > min(h, w):
+
+ assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
+ max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
+ min_rand_half_height = max(max_rand_half_height / 4,
+ self.min_width / 2)
+
+ max_rand_height = max_rand_half_height * 2
+ max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
+ self.min_width, self.max_width)
+ margin = int(
+ np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
+
+ inner_center_sample_mask = np.zeros_like(center_sample_mask)
+ inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
+ center_sample_mask[margin:h - margin, margin:w - margin]
+ kernel_size = int(np.clip(max_rand_half_height, 7, 21))
+ inner_center_sample_mask = cv2.erode(
+ inner_center_sample_mask,
+ np.ones((kernel_size, kernel_size), np.uint8))
+
+ center_candidates = np.argwhere(inner_center_sample_mask > 0)
+ num_center_candidates = len(center_candidates)
+ sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
+ rand_centers = center_candidates[sample_inds]
+
+ rand_top_height = np.random.randint(
+ min_rand_half_height,
+ max_rand_half_height,
+ size=(len(rand_centers), 1))
+ rand_bot_height = np.random.randint(
+ min_rand_half_height,
+ max_rand_half_height,
+ size=(len(rand_centers), 1))
+
+ rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
+ rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
+ scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
+ rand_cos = rand_cos * scale
+ rand_sin = rand_sin * scale
+
+ height = (rand_top_height + rand_bot_height)
+ width = np.clip(height * self.comp_w_h_ratio, self.min_width,
+ self.max_width)
+
+ rand_comp_attribs = np.hstack([
+ rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
+ np.zeros_like(rand_sin)
+ ]).astype(np.float32)
+
+ return rand_comp_attribs
+
+ def jitter_comp_attribs(self, comp_attribs, jitter_level):
+ """Jitter text components attributes.
+
+ Args:
+ comp_attribs (ndarray): The text component attributes.
+ jitter_level (float): The jitter level of text components
+ attributes.
+
+ Returns:
+ jittered_comp_attribs (ndarray): The jittered text component
+ attributes (x, y, h, w, cos, sin, comp_label).
+ """
+
+ assert comp_attribs.shape[1] == 7
+ assert comp_attribs.shape[0] > 0
+ assert isinstance(jitter_level, float)
+
+ x = comp_attribs[:, 0].reshape((-1, 1))
+ y = comp_attribs[:, 1].reshape((-1, 1))
+ h = comp_attribs[:, 2].reshape((-1, 1))
+ w = comp_attribs[:, 3].reshape((-1, 1))
+ cos = comp_attribs[:, 4].reshape((-1, 1))
+ sin = comp_attribs[:, 5].reshape((-1, 1))
+ comp_labels = comp_attribs[:, 6].reshape((-1, 1))
+
+ x += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
+ h * np.abs(cos) + w * np.abs(sin)) * jitter_level
+ y += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
+ h * np.abs(sin) + w * np.abs(cos)) * jitter_level
+
+ h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
+ ) * h * jitter_level
+ w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
+ ) * w * jitter_level
+
+ cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
+ ) * 2 * jitter_level
+ sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
+ ) * 2 * jitter_level
+
+ scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
+ cos = cos * scale
+ sin = sin * scale
+
+ jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])
+
+ return jittered_comp_attribs
+
+ def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
+ top_height_map, bot_height_map, sin_map, cos_map):
+ """Generate text component attributes.
+
+ Args:
+ center_lines (list[ndarray]): The list of text center lines .
+ text_mask (ndarray): The text region mask.
+ center_region_mask (ndarray): The text center region mask.
+ top_height_map (ndarray): The map on which the distance from points
+ to top side lines will be drawn for each pixel in text center
+ regions.
+ bot_height_map (ndarray): The map on which the distance from points
+ to bottom side lines will be drawn for each pixel in text
+ center regions.
+ sin_map (ndarray): The sin(theta) map where theta is the angle
+ between vector (top point - bottom point) and vector (1, 0).
+ cos_map (ndarray): The cos(theta) map where theta is the angle
+ between vector (top point - bottom point) and vector (1, 0).
+
+ Returns:
+ pad_comp_attribs (ndarray): The padded text component attributes
+ of a fixed size.
+ """
+
+ assert isinstance(center_lines, list)
+ assert (
+ text_mask.shape == center_region_mask.shape == top_height_map.shape
+ == bot_height_map.shape == sin_map.shape == cos_map.shape)
+
+ center_lines_mask = np.zeros_like(center_region_mask)
+ cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
+ center_lines_mask = center_lines_mask * center_region_mask
+ comp_centers = np.argwhere(center_lines_mask > 0)
+
+ y = comp_centers[:, 0]
+ x = comp_centers[:, 1]
+
+ top_height = top_height_map[y, x].reshape(
+ (-1, 1)) * self.comp_shrink_ratio
+ bot_height = bot_height_map[y, x].reshape(
+ (-1, 1)) * self.comp_shrink_ratio
+ sin = sin_map[y, x].reshape((-1, 1))
+ cos = cos_map[y, x].reshape((-1, 1))
+
+ top_mid_points = comp_centers + np.hstack(
+ [top_height * sin, top_height * cos])
+ bot_mid_points = comp_centers - np.hstack(
+ [bot_height * sin, bot_height * cos])
+
+ width = (top_height + bot_height) * self.comp_w_h_ratio
+ width = np.clip(width, self.min_width, self.max_width)
+ r = width / 2
+
+ tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
+ tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
+ br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
+ bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
+ text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
+
+ score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
+ text_comps = np.hstack([text_comps, score])
+ text_comps = la_nms(text_comps, self.text_comp_nms_thr)
+
+ if text_comps.shape[0] >= 1:
+ img_h, img_w = center_region_mask.shape
+ text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
+ text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
+
+ comp_centers = np.mean(
+ text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
+ x = comp_centers[:, 0]
+ y = comp_centers[:, 1]
+
+ height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
+ (-1, 1))
+ width = np.clip(height * self.comp_w_h_ratio, self.min_width,
+ self.max_width)
+
+ cos = cos_map[y, x].reshape((-1, 1))
+ sin = sin_map[y, x].reshape((-1, 1))
+
+ _, comp_label_mask = cv2.connectedComponents(
+ center_region_mask, connectivity=8)
+ comp_labels = comp_label_mask[y, x].reshape(
+ (-1, 1)).astype(np.float32)
+
+ x = x.reshape((-1, 1)).astype(np.float32)
+ y = y.reshape((-1, 1)).astype(np.float32)
+ comp_attribs = np.hstack(
+ [x, y, height, width, cos, sin, comp_labels])
+ comp_attribs = self.jitter_comp_attribs(comp_attribs,
+ self.jitter_level)
+
+ if comp_attribs.shape[0] < self.num_min_comps:
+ num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
+ rand_comp_attribs = self.generate_rand_comp_attribs(
+ num_rand_comps, 1 - text_mask)
+ comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
+ else:
+ comp_attribs = self.generate_rand_comp_attribs(self.num_min_comps,
+ 1 - text_mask)
+
+ num_comps = (np.ones(
+ (comp_attribs.shape[0], 1),
+ dtype=np.float32) * comp_attribs.shape[0])
+ comp_attribs = np.hstack([num_comps, comp_attribs])
+
+ if comp_attribs.shape[0] > self.num_max_comps:
+ comp_attribs = comp_attribs[:self.num_max_comps, :]
+ comp_attribs[:, 0] = self.num_max_comps
+
+ pad_comp_attribs = np.zeros(
+ (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
+ pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
+
+ return pad_comp_attribs
+
+ def generate_text_region_mask(self, img_size, text_polys):
+ """Generate text center region mask and geometry attribute maps.
+
+ Args:
+ img_size (tuple): The image size (height, width).
+ text_polys (list[list[ndarray]]): The list of text polygons.
+
+ Returns:
+ text_region_mask (ndarray): The text region mask.
+ """
+
+ assert isinstance(img_size, tuple)
+
+ h, w = img_size
+ text_region_mask = np.zeros((h, w), dtype=np.uint8)
+
+ for poly in text_polys:
+ polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
+ cv2.fillPoly(text_region_mask, polygon, 1)
+
+ return text_region_mask
+
+ def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
+ """Generate effective mask by setting the ineffective regions to 0 and
+ effective regions to 1.
+
+ Args:
+ mask_size (tuple): The mask size.
+ polygons_ignore (list[[ndarray]]: The list of ignored text
+ polygons.
+
+ Returns:
+ mask (ndarray): The effective mask of (height, width).
+ """
+ mask = np.ones(mask_size, dtype=np.uint8)
+
+ for poly in polygons_ignore:
+ instance = poly.astype(np.int32).reshape(1, -1, 2)
+ cv2.fillPoly(mask, instance, 0)
+
+ return mask
+
+ def generate_targets(self, data):
+ """Generate the gt targets for DRRG.
+
+ Args:
+ data (dict): The input result dictionary.
+
+ Returns:
+ data (dict): The output result dictionary.
+ """
+
+ assert isinstance(data, dict)
+
+ image = data['image']
+ polygons = data['polys']
+ ignore_tags = data['ignore_tags']
+ h, w, _ = image.shape
+
+ polygon_masks = []
+ polygon_masks_ignore = []
+ for tag, polygon in zip(ignore_tags, polygons):
+ if tag is True:
+ polygon_masks_ignore.append(polygon)
+ else:
+ polygon_masks.append(polygon)
+
+ gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
+ gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
+ (center_lines, gt_center_region_mask, gt_top_height_map,
+ gt_bot_height_map, gt_sin_map,
+ gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
+ polygon_masks)
+
+ gt_comp_attribs = self.generate_comp_attribs(
+ center_lines, gt_text_mask, gt_center_region_mask,
+ gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map)
+
+ mapping = {
+ 'gt_text_mask': gt_text_mask,
+ 'gt_center_region_mask': gt_center_region_mask,
+ 'gt_mask': gt_mask,
+ 'gt_top_height_map': gt_top_height_map,
+ 'gt_bot_height_map': gt_bot_height_map,
+ 'gt_sin_map': gt_sin_map,
+ 'gt_cos_map': gt_cos_map
+ }
+
+ data.update(mapping)
+ data['gt_comp_attribs'] = gt_comp_attribs
+ return data
+
+ def __call__(self, data):
+ data = self.generate_targets(data)
+ return data
diff --git a/ppocr/ext_op/__init__.py b/ppocr/ext_op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8307f3810bf56d34773d89c1049da3dabb1db7d2
--- /dev/null
+++ b/ppocr/ext_op/__init__.py
@@ -0,0 +1 @@
+from .roi_align_rotated.roi_align_rotated import RoIAlignRotated
diff --git a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2de86c53730c58bc58b0b6bd5e0098435339d4f9
--- /dev/null
+++ b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc
@@ -0,0 +1,528 @@
+
+// This code is refer from:
+// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cpu/roi_align_rotated.cpp
+
+#include
+#include
+#include
+
+#include "paddle/extension.h"
+
+#define PADDLE_WITH_CUDA
+#define CHECK_INPUT_SAME(x1, x2) \
+ PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
+#define CHECK_INPUT_CPU(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
+
+template struct PreCalc {
+ int pos1;
+ int pos2;
+ int pos3;
+ int pos4;
+ T w1;
+ T w2;
+ T w3;
+ T w4;
+};
+
+template
+void pre_calc_for_bilinear_interpolate(
+ const int height, const int width, const int pooled_height,
+ const int pooled_width, const int iy_upper, const int ix_upper,
+ T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
+ int roi_bin_grid_h, int roi_bin_grid_w, T roi_center_h, T roi_center_w,
+ T cos_theta, T sin_theta, std::vector> &pre_calc) {
+ int pre_calc_index = 0;
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ for (int iy = 0; iy < iy_upper; iy++) {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < ix_upper; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ // In image space, (y, x) is the order for Right Handed System,
+ // and this is essentially multiplying the point by a rotation matrix
+ // to rotate it counterclockwise through angle theta.
+ T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+ T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+ // deal with: inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ PreCalc pc;
+ pc.pos1 = 0;
+ pc.pos2 = 0;
+ pc.pos3 = 0;
+ pc.pos4 = 0;
+ pc.w1 = 0;
+ pc.w2 = 0;
+ pc.w3 = 0;
+ pc.w4 = 0;
+ pre_calc[pre_calc_index] = pc;
+ pre_calc_index += 1;
+ continue;
+ }
+
+ if (y < 0) {
+ y = 0;
+ }
+ if (x < 0) {
+ x = 0;
+ }
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ // save weights and indices
+ PreCalc pc;
+ pc.pos1 = y_low * width + x_low;
+ pc.pos2 = y_low * width + x_high;
+ pc.pos3 = y_high * width + x_low;
+ pc.pos4 = y_high * width + x_high;
+ pc.w1 = w1;
+ pc.w2 = w2;
+ pc.w3 = w3;
+ pc.w4 = w4;
+ pre_calc[pre_calc_index] = pc;
+
+ pre_calc_index += 1;
+ }
+ }
+ }
+ }
+}
+
+template
+void roi_align_rotated_cpu_forward(const int nthreads, const T *input,
+ const T &spatial_scale, const bool aligned,
+ const bool clockwise, const int channels,
+ const int height, const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio, const T *rois,
+ T *output) {
+ int n_rois = nthreads / channels / pooled_width / pooled_height;
+ // (n, c, ph, pw) is an element in the pooled output
+ // can be parallelized using omp
+ // #pragma omp parallel for num_threads(32)
+ for (int n = 0; n < n_rois; n++) {
+ int index_n = n * channels * pooled_width * pooled_height;
+
+ const T *current_roi = rois + n * 6;
+ int roi_batch_ind = current_roi[0];
+
+ // Do not use rounding; this implementation detail is critical
+ T offset = aligned ? (T)0.5 : (T)0.0;
+ T roi_center_w = current_roi[1] * spatial_scale - offset;
+ T roi_center_h = current_roi[2] * spatial_scale - offset;
+ T roi_width = current_roi[3] * spatial_scale;
+ T roi_height = current_roi[4] * spatial_scale;
+ T theta = current_roi[5];
+ if (clockwise) {
+ theta = -theta; // If clockwise, the angle needs to be reversed.
+ }
+ T cos_theta = cos(theta);
+ T sin_theta = sin(theta);
+
+ if (aligned) {
+ assert(roi_width >= 0 && roi_height >= 0);
+ } else { // for backward-compatibility only
+ roi_width = std::max(roi_width, (T)1.);
+ roi_height = std::max(roi_height, (T)1.);
+ }
+
+ T bin_size_h = static_cast(roi_height) / static_cast(pooled_height);
+ T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceilf(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin
+ const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+ // we want to precalculate indices and weights shared by all channels,
+ // this is the key point of optimization
+ std::vector> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
+ pooled_width * pooled_height);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ T roi_start_h = -roi_height / 2.0;
+ T roi_start_w = -roi_width / 2.0;
+
+ pre_calc_for_bilinear_interpolate(
+ height, width, pooled_height, pooled_width, roi_bin_grid_h,
+ roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
+ roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta,
+ sin_theta, pre_calc);
+
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * pooled_width * pooled_height;
+ const T *offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
+ int pre_calc_index = 0;
+
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ int index = index_n_c + ph * pooled_width + pw;
+
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ PreCalc pc = pre_calc[pre_calc_index];
+ output_val += pc.w1 * offset_input[pc.pos1] +
+ pc.w2 * offset_input[pc.pos2] +
+ pc.w3 * offset_input[pc.pos3] +
+ pc.w4 * offset_input[pc.pos4];
+
+ pre_calc_index += 1;
+ }
+ }
+ output_val /= count;
+
+ output[index] = output_val;
+ } // for pw
+ } // for ph
+ } // for c
+ } // for n
+}
+
+template
+void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
+ T &w1, T &w2, T &w3, T &w4, int &x_low,
+ int &x_high, int &y_low, int &y_high) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+
+ if (y < 0) {
+ y = 0;
+ }
+
+ if (x < 0) {
+ x = 0;
+ }
+
+ y_low = (int)y;
+ x_low = (int)x;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+
+template inline void add(T *address, const T &val) {
+ *address += val;
+}
+
+template
+void roi_align_rotated_cpu_backward(
+ const int nthreads,
+ // may not be contiguous. should index using n_stride, etc
+ const T *grad_output, const T &spatial_scale, const bool aligned,
+ const bool clockwise, const int channels, const int height, const int width,
+ const int pooled_height, const int pooled_width, const int sampling_ratio,
+ T *grad_input, const T *rois, const int n_stride, const int c_stride,
+ const int h_stride, const int w_stride) {
+ for (int index = 0; index < nthreads; index++) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T *current_roi = rois + n * 6;
+ int roi_batch_ind = current_roi[0];
+
+ // Do not use rounding; this implementation detail is critical
+ T offset = aligned ? (T)0.5 : (T)0.0;
+ T roi_center_w = current_roi[1] * spatial_scale - offset;
+ T roi_center_h = current_roi[2] * spatial_scale - offset;
+ T roi_width = current_roi[3] * spatial_scale;
+ T roi_height = current_roi[4] * spatial_scale;
+ T theta = current_roi[5];
+ if (clockwise) {
+ theta = -theta; // If clockwise, the angle needs to be reversed.
+ }
+ T cos_theta = cos(theta);
+ T sin_theta = sin(theta);
+
+ if (aligned) {
+ assert(roi_width >= 0 && roi_height >= 0);
+ } else { // for backward-compatibility only
+ roi_width = std::max(roi_width, (T)1.);
+ roi_height = std::max(roi_height, (T)1.);
+ }
+
+ T bin_size_h = static_cast(roi_height) / static_cast(pooled_height);
+ T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);
+
+ T *offset_grad_input =
+ grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+ int output_offset = n * n_stride + c * c_stride;
+ const T *offset_grad_output = grad_output + output_offset;
+ const T grad_output_this_bin =
+ offset_grad_output[ph * h_stride + pw * w_stride];
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceilf(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ T roi_start_h = -roi_height / 2.0;
+ T roi_start_w = -roi_width / 2.0;
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+ T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+ x_low, x_high, y_low, y_high);
+
+ T g1 = grad_output_this_bin * w1 / count;
+ T g2 = grad_output_this_bin * w2 / count;
+ T g3 = grad_output_this_bin * w3 / count;
+ T g4 = grad_output_this_bin * w4 / count;
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ // atomic add is not needed for now since it is single threaded
+ add(offset_grad_input + y_low * width + x_low, static_cast(g1));
+ add(offset_grad_input + y_low * width + x_high, static_cast(g2));
+ add(offset_grad_input + y_high * width + x_low, static_cast(g3));
+ add(offset_grad_input + y_high * width + x_high, static_cast(g4));
+ } // if
+ } // ix
+ } // iy
+ } // for
+} // ROIAlignRotatedBackward
+
+std::vector
+RoIAlignRotatedCPUForward(const paddle::Tensor &input,
+ const paddle::Tensor &rois, int aligned_height,
+ int aligned_width, float spatial_scale,
+ int sampling_ratio, bool aligned, bool clockwise) {
+ CHECK_INPUT_CPU(input);
+ CHECK_INPUT_CPU(rois);
+
+ auto num_rois = rois.shape()[0];
+
+ auto channels = input.shape()[1];
+ auto height = input.shape()[2];
+ auto width = input.shape()[3];
+
+ auto output =
+ paddle::empty({num_rois, channels, aligned_height, aligned_width},
+ input.type(), paddle::CPUPlace());
+ auto output_size = output.numel();
+
+ PD_DISPATCH_FLOATING_TYPES(
+ input.type(), "roi_align_rotated_cpu_forward", ([&] {
+ roi_align_rotated_cpu_forward(
+ output_size, input.data(),
+ static_cast(spatial_scale), aligned, clockwise, channels,
+ height, width, aligned_height, aligned_width, sampling_ratio,
+ rois.data(), output.data());
+ }));
+
+ return {output};
+}
+
+std::vector RoIAlignRotatedCPUBackward(
+ const paddle::Tensor &input, const paddle::Tensor &rois,
+ const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
+ float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
+
+ auto batch_size = input.shape()[0];
+ auto channels = input.shape()[1];
+ auto height = input.shape()[2];
+ auto width = input.shape()[3];
+
+ auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
+ input.type(), paddle::CPUPlace());
+
+ // get stride values to ensure indexing into gradients is correct.
+ int n_stride = grad_output.shape()[0];
+ int c_stride = grad_output.shape()[1];
+ int h_stride = grad_output.shape()[2];
+ int w_stride = grad_output.shape()[3];
+
+ PD_DISPATCH_FLOATING_TYPES(
+ grad_output.type(), "roi_align_rotated_cpu_backward", [&] {
+ roi_align_rotated_cpu_backward(
+ grad_output.numel(), grad_output.data(),
+ static_cast(spatial_scale), aligned, clockwise, channels,
+ height, width, aligned_height, aligned_width, sampling_ratio,
+ grad_input.data(), rois.data(), n_stride, c_stride,
+ h_stride, w_stride);
+ });
+ return {grad_input};
+}
+
+#ifdef PADDLE_WITH_CUDA
+std::vector
+RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
+ const paddle::Tensor &rois, int aligned_height,
+ int aligned_width, float spatial_scale,
+ int sampling_ratio, bool aligned, bool clockwise);
+#endif
+
+#ifdef PADDLE_WITH_CUDA
+std::vector RoIAlignRotatedCUDABackward(
+ const paddle::Tensor &input, const paddle::Tensor &rois,
+ const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
+ float spatial_scale, int sampling_ratio, bool aligned, bool clockwise);
+#endif
+
+std::vector
+RoIAlignRotatedForward(const paddle::Tensor &input, const paddle::Tensor &rois,
+ int aligned_height, int aligned_width,
+ float spatial_scale, int sampling_ratio, bool aligned,
+ bool clockwise) {
+ CHECK_INPUT_SAME(input, rois);
+ if (input.is_cpu()) {
+ return RoIAlignRotatedCPUForward(input, rois, aligned_height, aligned_width,
+ spatial_scale, sampling_ratio, aligned,
+ clockwise);
+#ifdef PADDLE_WITH_CUDA
+ } else if (input.is_gpu()) {
+ return RoIAlignRotatedCUDAForward(input, rois, aligned_height,
+ aligned_width, spatial_scale,
+ sampling_ratio, aligned, clockwise);
+#endif
+ } else {
+ PD_THROW("Unsupported device type for forward function of roi align "
+ "rotated operator.");
+ }
+}
+
+std::vector
+RoIAlignRotatedBackward(const paddle::Tensor &input, const paddle::Tensor &rois,
+ const paddle::Tensor &grad_output, int aligned_height,
+ int aligned_width, float spatial_scale,
+ int sampling_ratio, bool aligned, bool clockwise) {
+ CHECK_INPUT_SAME(input, rois);
+ if (input.is_cpu()) {
+ return RoIAlignRotatedCPUBackward(input, rois, grad_output, aligned_height,
+ aligned_width, spatial_scale,
+ sampling_ratio, aligned, clockwise);
+#ifdef PADDLE_WITH_CUDA
+ } else if (input.is_gpu()) {
+ return RoIAlignRotatedCUDABackward(input, rois, grad_output, aligned_height,
+ aligned_width, spatial_scale,
+ sampling_ratio, aligned, clockwise);
+#endif
+ } else {
+ PD_THROW("Unsupported device type for forward function of roi align "
+ "rotated operator.");
+ }
+}
+
+std::vector> InferShape(std::vector input_shape,
+ std::vector rois_shape) {
+ return {{rois_shape[0], input_shape[1], input_shape[2], input_shape[3]}};
+}
+
+std::vector>
+InferBackShape(std::vector input_shape,
+ std::vector rois_shape) {
+ return {input_shape};
+}
+
+std::vector InferDtype(paddle::DataType input_dtype,
+ paddle::DataType rois_dtype) {
+ return {input_dtype};
+}
+
+PD_BUILD_OP(roi_align_rotated)
+ .Inputs({"Input", "Rois"})
+ .Outputs({"Output"})
+ .Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
+ "sampling_ratio: int", "aligned: bool", "clockwise: bool"})
+ .SetKernelFn(PD_KERNEL(RoIAlignRotatedForward))
+ .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
+ .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
+
+PD_BUILD_GRAD_OP(roi_align_rotated)
+ .Inputs({"Input", "Rois", paddle::Grad("Output")})
+ .Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
+ "sampling_ratio: int", "aligned: bool", "clockwise: bool"})
+ .Outputs({paddle::Grad("Input")})
+ .SetKernelFn(PD_KERNEL(RoIAlignRotatedBackward))
+ .SetInferShapeFn(PD_INFER_SHAPE(InferBackShape));
diff --git a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu
new file mode 100644
index 0000000000000000000000000000000000000000..17bd47dc08be732bdb228da9696ee2d163179c73
--- /dev/null
+++ b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu
@@ -0,0 +1,380 @@
+
+// This code is refer from:
+// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh
+
+#include
+#include
+#include
+
+#include "paddle/extension.h"
+#include
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+#define THREADS_PER_BLOCK 512
+
+inline int GET_BLOCKS(const int N) {
+ int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+ int max_block_num = 4096;
+ return min(optimal_block_num, max_block_num);
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+
+static __inline__ __device__ double atomicAdd(double *address, double val) {
+ unsigned long long int *address_as_ull = (unsigned long long int *)address;
+ unsigned long long int old = *address_as_ull, assumed;
+ if (val == 0.0)
+ return __longlong_as_double(old);
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(val + __longlong_as_double(assumed)));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
+#endif
+
+template
+__device__ T bilinear_interpolate(const T *input, const int height,
+ const int width, T y, T x,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width)
+ return 0;
+
+ if (y <= 0)
+ y = 0;
+ if (x <= 0)
+ x = 0;
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ // do bilinear interpolation
+ T v1 = input[y_low * width + x_low];
+ T v2 = input[y_low * width + x_high];
+ T v3 = input[y_high * width + x_low];
+ T v4 = input[y_high * width + x_high];
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ return val;
+}
+
+template
+__device__ void
+bilinear_interpolate_gradient(const int height, const int width, T y, T x,
+ T &w1, T &w2, T &w3, T &w4, int &x_low,
+ int &x_high, int &y_low, int &y_high,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+
+ if (y <= 0)
+ y = 0;
+ if (x <= 0)
+ x = 0;
+
+ y_low = (int)y;
+ x_low = (int)x;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+
+/*** Forward ***/
+template
+__global__ void roi_align_rotated_cuda_forward_kernel(
+ const int nthreads, const scalar_t *bottom_data,
+ const scalar_t *bottom_rois, const scalar_t spatial_scale,
+ const int sample_num, const bool aligned, const bool clockwise,
+ const int channels, const int height, const int width,
+ const int pooled_height, const int pooled_width, scalar_t *top_data) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
+ int roi_batch_ind = offset_bottom_rois[0];
+
+ // Do not using rounding; this implementation detail is critical
+ scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
+ scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
+ scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
+ scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
+ scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
+ // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
+ scalar_t theta = offset_bottom_rois[5];
+ if (clockwise) {
+ theta = -theta; // If clockwise, the angle needs to be reversed.
+ }
+ if (!aligned) { // for backward-compatibility only
+ // Force malformed ROIs to be 1x1
+ roi_width = max(roi_width, (scalar_t)1.);
+ roi_height = max(roi_height, (scalar_t)1.);
+ }
+ scalar_t bin_size_h = static_cast(roi_height) /
+ static_cast(pooled_height);
+ scalar_t bin_size_w =
+ static_cast(roi_width) / static_cast(pooled_width);
+
+ const scalar_t *offset_bottom_data =
+ bottom_data + (roi_batch_ind * channels + c) * height * width;
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sample_num > 0)
+ ? sample_num
+ : ceilf(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ scalar_t roi_start_h = -roi_height / 2.0;
+ scalar_t roi_start_w = -roi_width / 2.0;
+ scalar_t cosscalar_theta = cos(theta);
+ scalar_t sinscalar_theta = sin(theta);
+
+ // We do average (integral) pooling inside a bin
+ const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+ scalar_t output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
+ const scalar_t yy =
+ roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const scalar_t xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta (counterclockwise) around the center and translate
+ scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
+ scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
+
+ scalar_t val = bilinear_interpolate(
+ offset_bottom_data, height, width, y, x, index);
+ output_val += val;
+ }
+ }
+ output_val /= count;
+
+ top_data[index] = output_val;
+ }
+}
+
+/*** Backward ***/
+template
+__global__ void roi_align_rotated_backward_cuda_kernel(
+ const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
+ const scalar_t spatial_scale, const int sample_num, const bool aligned,
+ const bool clockwise, const int channels, const int height, const int width,
+ const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
+ int roi_batch_ind = offset_bottom_rois[0];
+
+ // Do not round
+ scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
+ scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
+ scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
+ scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
+ scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
+ // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
+ scalar_t theta = offset_bottom_rois[5];
+ if (clockwise) {
+ theta = -theta; // If clockwise, the angle needs to be reversed.
+ }
+ if (!aligned) { // for backward-compatibility only
+ // Force malformed ROIs to be 1x1
+ roi_width = max(roi_width, (scalar_t)1.);
+ roi_height = max(roi_height, (scalar_t)1.);
+ }
+ scalar_t bin_size_h = static_cast(roi_height) /
+ static_cast(pooled_height);
+ scalar_t bin_size_w =
+ static_cast(roi_width) / static_cast(pooled_width);
+
+ scalar_t *offset_bottom_diff =
+ bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
+ const scalar_t *offset_top_diff = top_diff + top_offset;
+ const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sample_num > 0)
+ ? sample_num
+ : ceilf(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ scalar_t roi_start_h = -roi_height / 2.0;
+ scalar_t roi_start_w = -roi_width / 2.0;
+ scalar_t cosTheta = cos(theta);
+ scalar_t sinTheta = sin(theta);
+
+ // We do average (integral) pooling inside a bin
+ const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
+ const scalar_t yy =
+ roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const scalar_t xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
+ scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
+
+ scalar_t w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3,
+ w4, x_low, x_high, y_low,
+ y_high, index);
+
+ scalar_t g1 = top_diff_this_bin * w1 / count;
+ scalar_t g2 = top_diff_this_bin * w2 / count;
+ scalar_t g3 = top_diff_this_bin * w3 / count;
+ scalar_t g4 = top_diff_this_bin * w4 / count;
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
+ atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
+ atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
+ atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
+ } // if
+ } // ix
+ } // iy
+ } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignBackward
+
+std::vector
+RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
+ const paddle::Tensor &rois, int aligned_height,
+ int aligned_width, float spatial_scale,
+ int sampling_ratio, bool aligned, bool clockwise) {
+
+ auto num_rois = rois.shape()[0];
+
+ auto channels = input.shape()[1];
+ auto height = input.shape()[2];
+ auto width = input.shape()[3];
+
+ auto output =
+ paddle::empty({num_rois, channels, aligned_height, aligned_width},
+ input.type(), paddle::GPUPlace());
+ auto output_size = output.numel();
+
+ PD_DISPATCH_FLOATING_TYPES(
+ input.type(), "roi_align_rotated_cuda_forward_kernel", ([&] {
+ roi_align_rotated_cuda_forward_kernel<
+ data_t><<>>(
+ output_size, input.data(), rois.data(),
+ static_cast(spatial_scale), sampling_ratio, aligned,
+ clockwise, channels, height, width, aligned_height, aligned_width,
+ output.data());
+ }));
+
+ return {output};
+}
+
+std::vector RoIAlignRotatedCUDABackward(
+ const paddle::Tensor &input, const paddle::Tensor &rois,
+ const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
+ float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
+
+ auto num_rois = rois.shape()[0];
+
+ auto batch_size = input.shape()[0];
+ auto channels = input.shape()[1];
+ auto height = input.shape()[2];
+ auto width = input.shape()[3];
+
+ auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
+ input.type(), paddle::GPUPlace());
+
+ const int output_size = num_rois * aligned_height * aligned_width * channels;
+
+ PD_DISPATCH_FLOATING_TYPES(
+ grad_output.type(), "roi_align_rotated_backward_cuda_kernel", ([&] {
+ roi_align_rotated_backward_cuda_kernel<
+ data_t><<>>(
+ output_size, grad_output.data(), rois.data(),
+ spatial_scale, sampling_ratio, aligned, clockwise, channels, height,
+ width, aligned_height, aligned_width, grad_input.data());
+ }));
+ return {grad_input};
+}
\ No newline at end of file
diff --git a/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcca285c75f9c68ff15409810edcec887eed2026
--- /dev/null
+++ b/ppocr/ext_op/roi_align_rotated/roi_align_rotated.py
@@ -0,0 +1,66 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/roi_align_rotated.py
+"""
+
+import paddle
+import paddle.nn as nn
+from paddle.utils.cpp_extension import load
+custom_ops = load(
+ name="custom_jit_ops",
+ sources=[
+ "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc",
+ "ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu"
+ ])
+
+roi_align_rotated = custom_ops.roi_align_rotated
+
+
+class RoIAlignRotated(nn.Layer):
+ """RoI align pooling layer for rotated proposals.
+
+ """
+
+ def __init__(self,
+ out_size,
+ spatial_scale,
+ sample_num=0,
+ aligned=True,
+ clockwise=False):
+ super(RoIAlignRotated, self).__init__()
+
+ if isinstance(out_size, int):
+ self.out_h = out_size
+ self.out_w = out_size
+ elif isinstance(out_size, tuple):
+ assert len(out_size) == 2
+ assert isinstance(out_size[0], int)
+ assert isinstance(out_size[1], int)
+ self.out_h, self.out_w = out_size
+ else:
+ raise TypeError(
+ '"out_size" must be an integer or tuple of integers')
+
+ self.spatial_scale = float(spatial_scale)
+ self.sample_num = int(sample_num)
+ self.aligned = aligned
+ self.clockwise = clockwise
+
+ def forward(self, feats, rois):
+ output = roi_align_rotated(feats, rois, self.out_h, self.out_w,
+ self.spatial_scale, self.sample_num,
+ self.aligned, self.clockwise)
+ return output
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index ffee0a93e3993d075be8be510513bbb00012fc74..6abaa408b3f6995a0b4c377206e8a1551b48c56b 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -26,6 +26,7 @@ from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss
from .det_ct_loss import CTLoss
+from .det_drrg_loss import DRRGLoss
# rec loss
from .rec_ctc_loss import CTCLoss
@@ -70,7 +71,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
- 'SLALoss', 'CTLoss', 'RFLLoss'
+ 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/det_drrg_loss.py b/ppocr/losses/det_drrg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d4b521c7d1d0f29104abc3f315379827f98af7
--- /dev/null
+++ b/ppocr/losses/det_drrg_loss.py
@@ -0,0 +1,224 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/drrg_loss.py
+"""
+
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+
+class DRRGLoss(nn.Layer):
+ def __init__(self, ohem_ratio=3.0):
+ super().__init__()
+ self.ohem_ratio = ohem_ratio
+ self.downsample_ratio = 1.0
+
+ def balance_bce_loss(self, pred, gt, mask):
+ """Balanced Binary-CrossEntropy Loss.
+
+ Args:
+ pred (Tensor): Shape of :math:`(1, H, W)`.
+ gt (Tensor): Shape of :math:`(1, H, W)`.
+ mask (Tensor): Shape of :math:`(1, H, W)`.
+
+ Returns:
+ Tensor: Balanced bce loss.
+ """
+ assert pred.shape == gt.shape == mask.shape
+ assert paddle.all(pred >= 0) and paddle.all(pred <= 1)
+ assert paddle.all(gt >= 0) and paddle.all(gt <= 1)
+ positive = gt * mask
+ negative = (1 - gt) * mask
+ positive_count = int(positive.sum())
+
+ if positive_count > 0:
+ loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ positive_loss = paddle.sum(loss * positive)
+ negative_loss = loss * negative
+ negative_count = min(
+ int(negative.sum()), int(positive_count * self.ohem_ratio))
+ else:
+ positive_loss = paddle.to_tensor(0.0)
+ loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ negative_loss = loss * negative
+ negative_count = 100
+ negative_loss, _ = paddle.topk(
+ negative_loss.reshape([-1]), negative_count)
+
+ balance_loss = (positive_loss + paddle.sum(negative_loss)) / (
+ float(positive_count + negative_count) + 1e-5)
+
+ return balance_loss
+
+ def gcn_loss(self, gcn_data):
+ """CrossEntropy Loss from gcn module.
+
+ Args:
+ gcn_data (tuple(Tensor, Tensor)): The first is the
+ prediction with shape :math:`(N, 2)` and the
+ second is the gt label with shape :math:`(m, n)`
+ where :math:`m * n = N`.
+
+ Returns:
+ Tensor: CrossEntropy loss.
+ """
+ gcn_pred, gt_labels = gcn_data
+ gt_labels = gt_labels.reshape([-1])
+ loss = F.cross_entropy(gcn_pred, gt_labels)
+
+ return loss
+
+ def bitmasks2tensor(self, bitmasks, target_sz):
+ """Convert Bitmasks to tensor.
+
+ Args:
+ bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
+ for one img.
+ target_sz (tuple(int, int)): The target tensor of size
+ :math:`(H, W)`.
+
+ Returns:
+ list[Tensor]: The list of kernel tensors. Each element stands for
+ one kernel level.
+ """
+ batch_size = len(bitmasks)
+ results = []
+
+ kernel = []
+ for batch_inx in range(batch_size):
+ mask = bitmasks[batch_inx]
+ # hxw
+ mask_sz = mask.shape
+ # left, right, top, bottom
+ pad = [0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]]
+ mask = F.pad(mask, pad, mode='constant', value=0)
+ kernel.append(mask)
+ kernel = paddle.stack(kernel)
+ results.append(kernel)
+
+ return results
+
+ def forward(self, preds, labels):
+ """Compute Drrg loss.
+ """
+
+ assert isinstance(preds, tuple)
+ gt_text_mask, gt_center_region_mask, gt_mask, gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map = labels[
+ 1:8]
+
+ downsample_ratio = self.downsample_ratio
+
+ pred_maps, gcn_data = preds
+ pred_text_region = pred_maps[:, 0, :, :]
+ pred_center_region = pred_maps[:, 1, :, :]
+ pred_sin_map = pred_maps[:, 2, :, :]
+ pred_cos_map = pred_maps[:, 3, :, :]
+ pred_top_height_map = pred_maps[:, 4, :, :]
+ pred_bot_height_map = pred_maps[:, 5, :, :]
+ feature_sz = pred_maps.shape
+
+ # bitmask 2 tensor
+ mapping = {
+ 'gt_text_mask': paddle.cast(gt_text_mask, 'float32'),
+ 'gt_center_region_mask':
+ paddle.cast(gt_center_region_mask, 'float32'),
+ 'gt_mask': paddle.cast(gt_mask, 'float32'),
+ 'gt_top_height_map': paddle.cast(gt_top_height_map, 'float32'),
+ 'gt_bot_height_map': paddle.cast(gt_bot_height_map, 'float32'),
+ 'gt_sin_map': paddle.cast(gt_sin_map, 'float32'),
+ 'gt_cos_map': paddle.cast(gt_cos_map, 'float32')
+ }
+ gt = {}
+ for key, value in mapping.items():
+ gt[key] = value
+ if abs(downsample_ratio - 1.0) < 1e-2:
+ gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
+ else:
+ gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
+ gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
+ if key in ['gt_top_height_map', 'gt_bot_height_map']:
+ gt[key] = [item * downsample_ratio for item in gt[key]]
+ gt[key] = [item for item in gt[key]]
+
+ scale = paddle.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
+ pred_sin_map = pred_sin_map * scale
+ pred_cos_map = pred_cos_map * scale
+
+ loss_text = self.balance_bce_loss(
+ F.sigmoid(pred_text_region), gt['gt_text_mask'][0],
+ gt['gt_mask'][0])
+
+ text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0])
+ negative_text_mask = ((1 - gt['gt_text_mask'][0]) * gt['gt_mask'][0])
+ loss_center_map = F.binary_cross_entropy(
+ F.sigmoid(pred_center_region),
+ gt['gt_center_region_mask'][0],
+ reduction='none')
+ if int(text_mask.sum()) > 0:
+ loss_center_positive = paddle.sum(loss_center_map *
+ text_mask) / paddle.sum(text_mask)
+ else:
+ loss_center_positive = paddle.to_tensor(0.0)
+ loss_center_negative = paddle.sum(
+ loss_center_map *
+ negative_text_mask) / paddle.sum(negative_text_mask)
+ loss_center = loss_center_positive + 0.5 * loss_center_negative
+
+ center_mask = (gt['gt_center_region_mask'][0] * gt['gt_mask'][0])
+ if int(center_mask.sum()) > 0:
+ map_sz = pred_top_height_map.shape
+ ones = paddle.ones(map_sz, dtype='float32')
+ loss_top = F.smooth_l1_loss(
+ pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
+ ones,
+ reduction='none')
+ loss_bot = F.smooth_l1_loss(
+ pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
+ ones,
+ reduction='none')
+ gt_height = (
+ gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
+ loss_height = paddle.sum(
+ (paddle.log(gt_height + 1) *
+ (loss_top + loss_bot)) * center_mask) / paddle.sum(center_mask)
+
+ loss_sin = paddle.sum(
+ F.smooth_l1_loss(
+ pred_sin_map, gt['gt_sin_map'][0],
+ reduction='none') * center_mask) / paddle.sum(center_mask)
+ loss_cos = paddle.sum(
+ F.smooth_l1_loss(
+ pred_cos_map, gt['gt_cos_map'][0],
+ reduction='none') * center_mask) / paddle.sum(center_mask)
+ else:
+ loss_height = paddle.to_tensor(0.0)
+ loss_sin = paddle.to_tensor(0.0)
+ loss_cos = paddle.to_tensor(0.0)
+
+ loss_gcn = self.gcn_loss(gcn_data)
+
+ loss = loss_text + loss_center + loss_height + loss_sin + loss_cos + loss_gcn
+ results = dict(
+ loss=loss,
+ loss_text=loss_text,
+ loss_center=loss_center,
+ loss_height=loss_height,
+ loss_sin=loss_sin,
+ loss_cos=loss_cos,
+ loss_gcn=loss_gcn)
+
+ return results
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index ba180566c0c522154e7c16dc0eb4a6ec6cc5fe3d..63002140c5be4bd7e32b56995c6410ecc8a0fa36 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -24,6 +24,7 @@ def build_head(config):
from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead
from .det_ct_head import CT_Head
+ from .det_drrg_head import DRRGHead
# rec head
from .rec_ctc_head import CTCHead
@@ -54,7 +55,8 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
- 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
+ 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
+ 'DRRGHead'
]
#table head
diff --git a/ppocr/modeling/heads/det_drrg_head.py b/ppocr/modeling/heads/det_drrg_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aee1f8cb7734fd6093cd6ed11e5492ef5cd9785
--- /dev/null
+++ b/ppocr/modeling/heads/det_drrg_head.py
@@ -0,0 +1,191 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/drrg_head.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+import cv2
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from .gcn import GCN
+from .local_graph import LocalGraphs
+from .proposal_local_graph import ProposalLocalGraphs
+
+
+class DRRGHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ k_at_hops=(8, 4),
+ num_adjacent_linkages=3,
+ node_geo_feat_len=120,
+ pooling_scale=1.0,
+ pooling_output_size=(4, 3),
+ nms_thr=0.3,
+ min_width=8.0,
+ max_width=24.0,
+ comp_shrink_ratio=1.03,
+ comp_ratio=0.4,
+ comp_score_thr=0.3,
+ text_region_thr=0.2,
+ center_region_thr=0.2,
+ center_region_area_thr=50,
+ local_graph_thr=0.7,
+ **kwargs):
+ super().__init__()
+
+ assert isinstance(in_channels, int)
+ assert isinstance(k_at_hops, tuple)
+ assert isinstance(num_adjacent_linkages, int)
+ assert isinstance(node_geo_feat_len, int)
+ assert isinstance(pooling_scale, float)
+ assert isinstance(pooling_output_size, tuple)
+ assert isinstance(comp_shrink_ratio, float)
+ assert isinstance(nms_thr, float)
+ assert isinstance(min_width, float)
+ assert isinstance(max_width, float)
+ assert isinstance(comp_ratio, float)
+ assert isinstance(comp_score_thr, float)
+ assert isinstance(text_region_thr, float)
+ assert isinstance(center_region_thr, float)
+ assert isinstance(center_region_area_thr, int)
+ assert isinstance(local_graph_thr, float)
+
+ self.in_channels = in_channels
+ self.out_channels = 6
+ self.downsample_ratio = 1.0
+ self.k_at_hops = k_at_hops
+ self.num_adjacent_linkages = num_adjacent_linkages
+ self.node_geo_feat_len = node_geo_feat_len
+ self.pooling_scale = pooling_scale
+ self.pooling_output_size = pooling_output_size
+ self.comp_shrink_ratio = comp_shrink_ratio
+ self.nms_thr = nms_thr
+ self.min_width = min_width
+ self.max_width = max_width
+ self.comp_ratio = comp_ratio
+ self.comp_score_thr = comp_score_thr
+ self.text_region_thr = text_region_thr
+ self.center_region_thr = center_region_thr
+ self.center_region_area_thr = center_region_area_thr
+ self.local_graph_thr = local_graph_thr
+
+ self.out_conv = nn.Conv2D(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.graph_train = LocalGraphs(
+ self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
+ self.pooling_scale, self.pooling_output_size, self.local_graph_thr)
+
+ self.graph_test = ProposalLocalGraphs(
+ self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
+ self.pooling_scale, self.pooling_output_size, self.nms_thr,
+ self.min_width, self.max_width, self.comp_shrink_ratio,
+ self.comp_ratio, self.comp_score_thr, self.text_region_thr,
+ self.center_region_thr, self.center_region_area_thr)
+
+ pool_w, pool_h = self.pooling_output_size
+ node_feat_len = (pool_w * pool_h) * (
+ self.in_channels + self.out_channels) + self.node_geo_feat_len
+ self.gcn = GCN(node_feat_len)
+
+ def forward(self, inputs, targets=None):
+ """
+ Args:
+ inputs (Tensor): Shape of :math:`(N, C, H, W)`.
+ gt_comp_attribs (list[ndarray]): The padded text component
+ attributes. Shape: (num_component, 8).
+
+ Returns:
+ tuple: Returns (pred_maps, (gcn_pred, gt_labels)).
+
+ - | pred_maps (Tensor): Prediction map with shape
+ :math:`(N, C_{out}, H, W)`.
+ - | gcn_pred (Tensor): Prediction from GCN module, with
+ shape :math:`(N, 2)`.
+ - | gt_labels (Tensor): Ground-truth label with shape
+ :math:`(N, 8)`.
+ """
+ if self.training:
+ assert targets is not None
+ gt_comp_attribs = targets[7]
+ pred_maps = self.out_conv(inputs)
+ feat_maps = paddle.concat([inputs, pred_maps], axis=1)
+ node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
+ feat_maps, np.stack(gt_comp_attribs))
+
+ gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
+
+ return pred_maps, (gcn_pred, gt_labels)
+ else:
+ return self.single_test(inputs)
+
+ def single_test(self, feat_maps):
+ r"""
+ Args:
+ feat_maps (Tensor): Shape of :math:`(N, C, H, W)`.
+
+ Returns:
+ tuple: Returns (edge, score, text_comps).
+
+ - | edge (ndarray): The edge array of shape :math:`(N, 2)`
+ where each row is a pair of text component indices
+ that makes up an edge in graph.
+ - | score (ndarray): The score array of shape :math:`(N,)`,
+ corresponding to the edge above.
+ - | text_comps (ndarray): The text components of shape
+ :math:`(N, 9)` where each row corresponds to one box and
+ its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
+ """
+ pred_maps = self.out_conv(feat_maps)
+ feat_maps = paddle.concat([feat_maps, pred_maps], axis=1)
+
+ none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
+
+ (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
+ pivot_local_graphs, text_comps) = graph_data
+
+ if none_flag:
+ return None, None, None
+ gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
+ pivots_knn_inds)
+ pred_labels = F.softmax(gcn_pred, axis=1)
+
+ edges = []
+ scores = []
+ pivot_local_graphs = pivot_local_graphs.squeeze().numpy()
+
+ for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
+ pivot = pivot_local_graph[0]
+ for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
+ neighbor = pivot_local_graph[neighbor_ind.item()]
+ edges.append([pivot, neighbor])
+ scores.append(pred_labels[pivot_ind * pivots_knn_inds.shape[1] +
+ k_ind, 1].item())
+
+ edges = np.asarray(edges)
+ scores = np.asarray(scores)
+
+ return edges, scores, text_comps
diff --git a/ppocr/modeling/heads/gcn.py b/ppocr/modeling/heads/gcn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d123f067cb7640575e7b6cfdeb0ab1826ab62aab
--- /dev/null
+++ b/ppocr/modeling/heads/gcn.py
@@ -0,0 +1,113 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/gcn.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class BatchNorm1D(nn.BatchNorm1D):
+ def __init__(self,
+ num_features,
+ eps=1e-05,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True):
+ momentum = 1 - momentum
+ weight_attr = None
+ bias_attr = None
+ if not affine:
+ weight_attr = paddle.ParamAttr(learning_rate=0.0)
+ bias_attr = paddle.ParamAttr(learning_rate=0.0)
+ super().__init__(
+ num_features,
+ momentum=momentum,
+ epsilon=eps,
+ weight_attr=weight_attr,
+ bias_attr=bias_attr,
+ use_global_stats=track_running_stats)
+
+
+class MeanAggregator(nn.Layer):
+ def forward(self, features, A):
+ x = paddle.bmm(A, features)
+ return x
+
+
+class GraphConv(nn.Layer):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.weight = self.create_parameter(
+ [in_dim * 2, out_dim],
+ default_initializer=nn.initializer.XavierUniform())
+ self.bias = self.create_parameter(
+ [out_dim],
+ is_bias=True,
+ default_initializer=nn.initializer.Assign([0] * out_dim))
+
+ self.aggregator = MeanAggregator()
+
+ def forward(self, features, A):
+ b, n, d = features.shape
+ assert d == self.in_dim
+ agg_feats = self.aggregator(features, A)
+ cat_feats = paddle.concat([features, agg_feats], axis=2)
+ out = paddle.einsum('bnd,df->bnf', cat_feats, self.weight)
+ out = F.relu(out + self.bias)
+ return out
+
+
+class GCN(nn.Layer):
+ def __init__(self, feat_len):
+ super(GCN, self).__init__()
+ self.bn0 = BatchNorm1D(feat_len, affine=False)
+ self.conv1 = GraphConv(feat_len, 512)
+ self.conv2 = GraphConv(512, 256)
+ self.conv3 = GraphConv(256, 128)
+ self.conv4 = GraphConv(128, 64)
+ self.classifier = nn.Sequential(
+ nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))
+
+ def forward(self, x, A, knn_inds):
+
+ num_local_graphs, num_max_nodes, feat_len = x.shape
+
+ x = x.reshape([-1, feat_len])
+ x = self.bn0(x)
+ x = x.reshape([num_local_graphs, num_max_nodes, feat_len])
+
+ x = self.conv1(x, A)
+ x = self.conv2(x, A)
+ x = self.conv3(x, A)
+ x = self.conv4(x, A)
+ k = knn_inds.shape[-1]
+ mid_feat_len = x.shape[-1]
+ edge_feat = paddle.zeros([num_local_graphs, k, mid_feat_len])
+ for graph_ind in range(num_local_graphs):
+ edge_feat[graph_ind, :, :] = x[graph_ind][paddle.to_tensor(knn_inds[
+ graph_ind])]
+ edge_feat = edge_feat.reshape([-1, mid_feat_len])
+ pred = self.classifier(edge_feat)
+
+ return pred
diff --git a/ppocr/modeling/heads/local_graph.py b/ppocr/modeling/heads/local_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..50fe6d72236df7afc2de3fda9e2e5db404641f34
--- /dev/null
+++ b/ppocr/modeling/heads/local_graph.py
@@ -0,0 +1,388 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/local_graph.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from ppocr.ext_op import RoIAlignRotated
+
+
+def normalize_adjacent_matrix(A):
+ assert A.ndim == 2
+ assert A.shape[0] == A.shape[1]
+
+ A = A + np.eye(A.shape[0])
+ d = np.sum(A, axis=0)
+ d = np.clip(d, 0, None)
+ d_inv = np.power(d, -0.5).flatten()
+ d_inv[np.isinf(d_inv)] = 0.0
+ d_inv = np.diag(d_inv)
+ G = A.dot(d_inv).transpose().dot(d_inv)
+ return G
+
+
+def euclidean_distance_matrix(A, B):
+ """Calculate the Euclidean distance matrix.
+
+ Args:
+ A (ndarray): The point sequence.
+ B (ndarray): The point sequence with the same dimensions as A.
+
+ returns:
+ D (ndarray): The Euclidean distance matrix.
+ """
+ assert A.ndim == 2
+ assert B.ndim == 2
+ assert A.shape[1] == B.shape[1]
+
+ m = A.shape[0]
+ n = B.shape[0]
+
+ A_dots = (A * A).sum(axis=1).reshape((m, 1)) * np.ones(shape=(1, n))
+ B_dots = (B * B).sum(axis=1) * np.ones(shape=(m, 1))
+ D_squared = A_dots + B_dots - 2 * A.dot(B.T)
+
+ zero_mask = np.less(D_squared, 0.0)
+ D_squared[zero_mask] = 0.0
+ D = np.sqrt(D_squared)
+ return D
+
+
+def feature_embedding(input_feats, out_feat_len):
+ """Embed features. This code was partially adapted from
+ https://github.com/GXYM/DRRG licensed under the MIT license.
+
+ Args:
+ input_feats (ndarray): The input features of shape (N, d), where N is
+ the number of nodes in graph, d is the input feature vector length.
+ out_feat_len (int): The length of output feature vector.
+
+ Returns:
+ embedded_feats (ndarray): The embedded features.
+ """
+ assert input_feats.ndim == 2
+ assert isinstance(out_feat_len, int)
+ assert out_feat_len >= input_feats.shape[1]
+
+ num_nodes = input_feats.shape[0]
+ feat_dim = input_feats.shape[1]
+ feat_repeat_times = out_feat_len // feat_dim
+ residue_dim = out_feat_len % feat_dim
+
+ if residue_dim > 0:
+ embed_wave = np.array([
+ np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
+ for j in range(feat_repeat_times + 1)
+ ]).reshape((feat_repeat_times + 1, 1, 1))
+ repeat_feats = np.repeat(
+ np.expand_dims(
+ input_feats, axis=0), feat_repeat_times, axis=0)
+ residue_feats = np.hstack([
+ input_feats[:, 0:residue_dim], np.zeros(
+ (num_nodes, feat_dim - residue_dim))
+ ])
+ residue_feats = np.expand_dims(residue_feats, axis=0)
+ repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
+ embedded_feats = repeat_feats / embed_wave
+ embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
+ embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
+ embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
+ (num_nodes, -1))[:, 0:out_feat_len]
+ else:
+ embed_wave = np.array([
+ np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
+ for j in range(feat_repeat_times)
+ ]).reshape((feat_repeat_times, 1, 1))
+ repeat_feats = np.repeat(
+ np.expand_dims(
+ input_feats, axis=0), feat_repeat_times, axis=0)
+ embedded_feats = repeat_feats / embed_wave
+ embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
+ embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
+ embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
+ (num_nodes, -1)).astype(np.float32)
+
+ return embedded_feats
+
+
+class LocalGraphs:
+ def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
+ pooling_scale, pooling_output_size, local_graph_thr):
+
+ assert len(k_at_hops) == 2
+ assert all(isinstance(n, int) for n in k_at_hops)
+ assert isinstance(num_adjacent_linkages, int)
+ assert isinstance(node_geo_feat_len, int)
+ assert isinstance(pooling_scale, float)
+ assert all(isinstance(n, int) for n in pooling_output_size)
+ assert isinstance(local_graph_thr, float)
+
+ self.k_at_hops = k_at_hops
+ self.num_adjacent_linkages = num_adjacent_linkages
+ self.node_geo_feat_dim = node_geo_feat_len
+ self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
+ self.local_graph_thr = local_graph_thr
+
+ def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
+ """Generate local graphs for GCN to predict which instance a text
+ component belongs to.
+
+ Args:
+ sorted_dist_inds (ndarray): The complete graph node indices, which
+ is sorted according to the Euclidean distance.
+ gt_comp_labels(ndarray): The ground truth labels define the
+ instance to which the text components (nodes in graphs) belong.
+
+ Returns:
+ pivot_local_graphs(list[list[int]]): The list of local graph
+ neighbor indices of pivots.
+ pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
+ of pivots.
+ """
+
+ assert sorted_dist_inds.ndim == 2
+ assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
+ gt_comp_labels.shape[0])
+
+ knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
+ pivot_local_graphs = []
+ pivot_knns = []
+ for pivot_ind, knn in enumerate(knn_graph):
+
+ local_graph_neighbors = set(knn)
+
+ for neighbor_ind in knn:
+ local_graph_neighbors.update(
+ set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
+ 1]))
+
+ local_graph_neighbors.discard(pivot_ind)
+ pivot_local_graph = list(local_graph_neighbors)
+ pivot_local_graph.insert(0, pivot_ind)
+ pivot_knn = [pivot_ind] + list(knn)
+
+ if pivot_ind < 1:
+ pivot_local_graphs.append(pivot_local_graph)
+ pivot_knns.append(pivot_knn)
+ else:
+ add_flag = True
+ for graph_ind, added_knn in enumerate(pivot_knns):
+ added_pivot_ind = added_knn[0]
+ added_local_graph = pivot_local_graphs[graph_ind]
+
+ union = len(
+ set(pivot_local_graph[1:]).union(
+ set(added_local_graph[1:])))
+ intersect = len(
+ set(pivot_local_graph[1:]).intersection(
+ set(added_local_graph[1:])))
+ local_graph_iou = intersect / (union + 1e-8)
+
+ if (local_graph_iou > self.local_graph_thr and
+ pivot_ind in added_knn and
+ gt_comp_labels[added_pivot_ind] ==
+ gt_comp_labels[pivot_ind] and
+ gt_comp_labels[pivot_ind] != 0):
+ add_flag = False
+ break
+ if add_flag:
+ pivot_local_graphs.append(pivot_local_graph)
+ pivot_knns.append(pivot_knn)
+
+ return pivot_local_graphs, pivot_knns
+
+ def generate_gcn_input(self, node_feat_batch, node_label_batch,
+ local_graph_batch, knn_batch, sorted_dist_ind_batch):
+ """Generate graph convolution network input data.
+
+ Args:
+ node_feat_batch (List[Tensor]): The batched graph node features.
+ node_label_batch (List[ndarray]): The batched text component
+ labels.
+ local_graph_batch (List[List[list[int]]]): The local graph node
+ indices of image batch.
+ knn_batch (List[List[list[int]]]): The knn graph node indices of
+ image batch.
+ sorted_dist_ind_batch (list[ndarray]): The node indices sorted
+ according to the Euclidean distance.
+
+ Returns:
+ local_graphs_node_feat (Tensor): The node features of graph.
+ adjacent_matrices (Tensor): The adjacent matrices of local graphs.
+ pivots_knn_inds (Tensor): The k-nearest neighbor indices in
+ local graph.
+ gt_linkage (Tensor): The surpervision signal of GCN for linkage
+ prediction.
+ """
+ assert isinstance(node_feat_batch, list)
+ assert isinstance(node_label_batch, list)
+ assert isinstance(local_graph_batch, list)
+ assert isinstance(knn_batch, list)
+ assert isinstance(sorted_dist_ind_batch, list)
+
+ num_max_nodes = max([
+ len(pivot_local_graph)
+ for pivot_local_graphs in local_graph_batch
+ for pivot_local_graph in pivot_local_graphs
+ ])
+
+ local_graphs_node_feat = []
+ adjacent_matrices = []
+ pivots_knn_inds = []
+ pivots_gt_linkage = []
+
+ for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
+ node_feats = node_feat_batch[batch_ind]
+ pivot_local_graphs = local_graph_batch[batch_ind]
+ pivot_knns = knn_batch[batch_ind]
+ node_labels = node_label_batch[batch_ind]
+
+ for graph_ind, pivot_knn in enumerate(pivot_knns):
+ pivot_local_graph = pivot_local_graphs[graph_ind]
+ num_nodes = len(pivot_local_graph)
+ pivot_ind = pivot_local_graph[0]
+ node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
+
+ knn_inds = paddle.to_tensor(
+ [node2ind_map[i] for i in pivot_knn[1:]])
+ pivot_feats = node_feats[pivot_ind]
+ normalized_feats = node_feats[paddle.to_tensor(
+ pivot_local_graph)] - pivot_feats
+
+ adjacent_matrix = np.zeros(
+ (num_nodes, num_nodes), dtype=np.float32)
+ for node in pivot_local_graph:
+ neighbors = sorted_dist_inds[node, 1:
+ self.num_adjacent_linkages + 1]
+ for neighbor in neighbors:
+ if neighbor in pivot_local_graph:
+
+ adjacent_matrix[node2ind_map[node], node2ind_map[
+ neighbor]] = 1
+ adjacent_matrix[node2ind_map[neighbor],
+ node2ind_map[node]] = 1
+
+ adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
+ pad_adjacent_matrix = paddle.zeros(
+ (num_max_nodes, num_max_nodes))
+ pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
+ paddle.to_tensor(adjacent_matrix), 'float32')
+
+ pad_normalized_feats = paddle.concat(
+ [
+ normalized_feats, paddle.zeros(
+ (num_max_nodes - num_nodes,
+ normalized_feats.shape[1]))
+ ],
+ axis=0)
+ local_graph_labels = node_labels[pivot_local_graph]
+ knn_labels = local_graph_labels[knn_inds.numpy()]
+ link_labels = ((node_labels[pivot_ind] == knn_labels) &
+ (node_labels[pivot_ind] > 0)).astype(np.int64)
+ link_labels = paddle.to_tensor(link_labels)
+
+ local_graphs_node_feat.append(pad_normalized_feats)
+ adjacent_matrices.append(pad_adjacent_matrix)
+ pivots_knn_inds.append(knn_inds)
+ pivots_gt_linkage.append(link_labels)
+
+ local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
+ adjacent_matrices = paddle.stack(adjacent_matrices, 0)
+ pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
+ pivots_gt_linkage = paddle.stack(pivots_gt_linkage, 0)
+
+ return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
+ pivots_gt_linkage)
+
+ def __call__(self, feat_maps, comp_attribs):
+ """Generate local graphs as GCN input.
+
+ Args:
+ feat_maps (Tensor): The feature maps to extract the content
+ features of text components.
+ comp_attribs (ndarray): The text component attributes.
+
+ Returns:
+ local_graphs_node_feat (Tensor): The node features of graph.
+ adjacent_matrices (Tensor): The adjacent matrices of local graphs.
+ pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
+ graph.
+ gt_linkage (Tensor): The surpervision signal of GCN for linkage
+ prediction.
+ """
+
+ assert isinstance(feat_maps, paddle.Tensor)
+ assert comp_attribs.ndim == 3
+ assert comp_attribs.shape[2] == 8
+
+ sorted_dist_inds_batch = []
+ local_graph_batch = []
+ knn_batch = []
+ node_feat_batch = []
+ node_label_batch = []
+
+ for batch_ind in range(comp_attribs.shape[0]):
+ num_comps = int(comp_attribs[batch_ind, 0, 0])
+ comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
+ node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(
+ np.int32)
+
+ comp_centers = comp_geo_attribs[:, 0:2]
+ distance_matrix = euclidean_distance_matrix(comp_centers,
+ comp_centers)
+
+ batch_id = np.zeros(
+ (comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
+ comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
+ angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
+ comp_geo_attribs[:, -1])
+ angle = angle.reshape((-1, 1))
+ rotated_rois = np.hstack(
+ [batch_id, comp_geo_attribs[:, :-2], angle])
+ rois = paddle.to_tensor(rotated_rois)
+ content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
+ rois)
+
+ content_feats = content_feats.reshape([content_feats.shape[0], -1])
+ geo_feats = feature_embedding(comp_geo_attribs,
+ self.node_geo_feat_dim)
+ geo_feats = paddle.to_tensor(geo_feats)
+ node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
+
+ sorted_dist_inds = np.argsort(distance_matrix, axis=1)
+ pivot_local_graphs, pivot_knns = self.generate_local_graphs(
+ sorted_dist_inds, node_labels)
+
+ node_feat_batch.append(node_feats)
+ node_label_batch.append(node_labels)
+ local_graph_batch.append(pivot_local_graphs)
+ knn_batch.append(pivot_knns)
+ sorted_dist_inds_batch.append(sorted_dist_inds)
+
+ (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
+ self.generate_gcn_input(node_feat_batch,
+ node_label_batch,
+ local_graph_batch,
+ knn_batch,
+ sorted_dist_inds_batch)
+
+ return node_feats, adjacent_matrices, knn_inds, gt_linkage
diff --git a/ppocr/modeling/heads/proposal_local_graph.py b/ppocr/modeling/heads/proposal_local_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..7887c4ff42f8ae9d1826a71f01208cd81bb2d52c
--- /dev/null
+++ b/ppocr/modeling/heads/proposal_local_graph.py
@@ -0,0 +1,412 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/proposal_local_graph.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from lanms import merge_quadrangle_n9 as la_nms
+
+from ppocr.ext_op import RoIAlignRotated
+from .local_graph import (euclidean_distance_matrix, feature_embedding,
+ normalize_adjacent_matrix)
+
+
+def fill_hole(input_mask):
+ h, w = input_mask.shape
+ canvas = np.zeros((h + 2, w + 2), np.uint8)
+ canvas[1:h + 1, 1:w + 1] = input_mask.copy()
+
+ mask = np.zeros((h + 4, w + 4), np.uint8)
+
+ cv2.floodFill(canvas, mask, (0, 0), 1)
+ canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
+
+ return ~canvas | input_mask
+
+
+class ProposalLocalGraphs:
+ def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
+ pooling_scale, pooling_output_size, nms_thr, min_width,
+ max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
+ text_region_thr, center_region_thr, center_region_area_thr):
+
+ assert len(k_at_hops) == 2
+ assert isinstance(k_at_hops, tuple)
+ assert isinstance(num_adjacent_linkages, int)
+ assert isinstance(node_geo_feat_len, int)
+ assert isinstance(pooling_scale, float)
+ assert isinstance(pooling_output_size, tuple)
+ assert isinstance(nms_thr, float)
+ assert isinstance(min_width, float)
+ assert isinstance(max_width, float)
+ assert isinstance(comp_shrink_ratio, float)
+ assert isinstance(comp_w_h_ratio, float)
+ assert isinstance(comp_score_thr, float)
+ assert isinstance(text_region_thr, float)
+ assert isinstance(center_region_thr, float)
+ assert isinstance(center_region_area_thr, int)
+
+ self.k_at_hops = k_at_hops
+ self.active_connection = num_adjacent_linkages
+ self.local_graph_depth = len(self.k_at_hops)
+ self.node_geo_feat_dim = node_geo_feat_len
+ self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
+ self.nms_thr = nms_thr
+ self.min_width = min_width
+ self.max_width = max_width
+ self.comp_shrink_ratio = comp_shrink_ratio
+ self.comp_w_h_ratio = comp_w_h_ratio
+ self.comp_score_thr = comp_score_thr
+ self.text_region_thr = text_region_thr
+ self.center_region_thr = center_region_thr
+ self.center_region_area_thr = center_region_area_thr
+
+ def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
+ cos_map, comp_score_thr, min_width, max_width,
+ comp_shrink_ratio, comp_w_h_ratio):
+ """Propose text components.
+
+ Args:
+ score_map (ndarray): The score map for NMS.
+ top_height_map (ndarray): The predicted text height map from each
+ pixel in text center region to top sideline.
+ bot_height_map (ndarray): The predicted text height map from each
+ pixel in text center region to bottom sideline.
+ sin_map (ndarray): The predicted sin(theta) map.
+ cos_map (ndarray): The predicted cos(theta) map.
+ comp_score_thr (float): The score threshold of text component.
+ min_width (float): The minimum width of text components.
+ max_width (float): The maximum width of text components.
+ comp_shrink_ratio (float): The shrink ratio of text components.
+ comp_w_h_ratio (float): The width to height ratio of text
+ components.
+
+ Returns:
+ text_comps (ndarray): The text components.
+ """
+
+ comp_centers = np.argwhere(score_map > comp_score_thr)
+ comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
+ y = comp_centers[:, 0]
+ x = comp_centers[:, 1]
+
+ top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
+ bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
+ sin = sin_map[y, x].reshape((-1, 1))
+ cos = cos_map[y, x].reshape((-1, 1))
+
+ top_mid_pts = comp_centers + np.hstack(
+ [top_height * sin, top_height * cos])
+ bot_mid_pts = comp_centers - np.hstack(
+ [bot_height * sin, bot_height * cos])
+
+ width = (top_height + bot_height) * comp_w_h_ratio
+ width = np.clip(width, min_width, max_width)
+ r = width / 2
+
+ tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
+ tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
+ br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
+ bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
+ text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
+
+ score = score_map[y, x].reshape((-1, 1))
+ text_comps = np.hstack([text_comps, score])
+
+ return text_comps
+
+ def propose_comps_and_attribs(self, text_region_map, center_region_map,
+ top_height_map, bot_height_map, sin_map,
+ cos_map):
+ """Generate text components and attributes.
+
+ Args:
+ text_region_map (ndarray): The predicted text region probability
+ map.
+ center_region_map (ndarray): The predicted text center region
+ probability map.
+ top_height_map (ndarray): The predicted text height map from each
+ pixel in text center region to top sideline.
+ bot_height_map (ndarray): The predicted text height map from each
+ pixel in text center region to bottom sideline.
+ sin_map (ndarray): The predicted sin(theta) map.
+ cos_map (ndarray): The predicted cos(theta) map.
+
+ Returns:
+ comp_attribs (ndarray): The text component attributes.
+ text_comps (ndarray): The text components.
+ """
+
+ assert (text_region_map.shape == center_region_map.shape ==
+ top_height_map.shape == bot_height_map.shape == sin_map.shape ==
+ cos_map.shape)
+ text_mask = text_region_map > self.text_region_thr
+ center_region_mask = (
+ center_region_map > self.center_region_thr) * text_mask
+
+ scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
+ sin_map, cos_map = sin_map * scale, cos_map * scale
+
+ center_region_mask = fill_hole(center_region_mask)
+ center_region_contours, _ = cv2.findContours(
+ center_region_mask.astype(np.uint8), cv2.RETR_TREE,
+ cv2.CHAIN_APPROX_SIMPLE)
+
+ mask_sz = center_region_map.shape
+ comp_list = []
+ for contour in center_region_contours:
+ current_center_mask = np.zeros(mask_sz)
+ cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
+ if current_center_mask.sum() <= self.center_region_area_thr:
+ continue
+ score_map = text_region_map * current_center_mask
+
+ text_comps = self.propose_comps(
+ score_map, top_height_map, bot_height_map, sin_map, cos_map,
+ self.comp_score_thr, self.min_width, self.max_width,
+ self.comp_shrink_ratio, self.comp_w_h_ratio)
+
+ text_comps = la_nms(text_comps, self.nms_thr)
+ text_comp_mask = np.zeros(mask_sz)
+ text_comp_boxes = text_comps[:, :8].reshape(
+ (-1, 4, 2)).astype(np.int32)
+
+ cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
+ if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
+ continue
+ if text_comps.shape[-1] > 0:
+ comp_list.append(text_comps)
+
+ if len(comp_list) <= 0:
+ return None, None
+
+ text_comps = np.vstack(comp_list)
+ text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
+ centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
+ x = centers[:, 0]
+ y = centers[:, 1]
+
+ scores = []
+ for text_comp_box in text_comp_boxes:
+ text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
+ mask_sz[1] - 1)
+ text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
+ mask_sz[0] - 1)
+ min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
+ max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
+ text_comp_box = text_comp_box - min_coord
+ box_sz = (max_coord - min_coord + 1)
+ temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
+ cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
+ temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] + 1),
+ min_coord[0]:(max_coord[0] + 1)]
+ score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
+ scores.append(score)
+ scores = np.array(scores).reshape((-1, 1))
+ text_comps = np.hstack([text_comps[:, :-1], scores])
+
+ h = top_height_map[y, x].reshape(
+ (-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
+ w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
+ sin = sin_map[y, x].reshape((-1, 1))
+ cos = cos_map[y, x].reshape((-1, 1))
+
+ x = x.reshape((-1, 1))
+ y = y.reshape((-1, 1))
+ comp_attribs = np.hstack([x, y, h, w, cos, sin])
+
+ return comp_attribs, text_comps
+
+ def generate_local_graphs(self, sorted_dist_inds, node_feats):
+ """Generate local graphs and graph convolution network input data.
+
+ Args:
+ sorted_dist_inds (ndarray): The node indices sorted according to
+ the Euclidean distance.
+ node_feats (tensor): The features of nodes in graph.
+
+ Returns:
+ local_graphs_node_feats (tensor): The features of nodes in local
+ graphs.
+ adjacent_matrices (tensor): The adjacent matrices.
+ pivots_knn_inds (tensor): The k-nearest neighbor indices in
+ local graphs.
+ pivots_local_graphs (tensor): The indices of nodes in local
+ graphs.
+ """
+
+ assert sorted_dist_inds.ndim == 2
+ assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
+ node_feats.shape[0])
+
+ knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
+ pivot_local_graphs = []
+ pivot_knns = []
+
+ for pivot_ind, knn in enumerate(knn_graph):
+
+ local_graph_neighbors = set(knn)
+
+ for neighbor_ind in knn:
+ local_graph_neighbors.update(
+ set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
+ 1]))
+
+ local_graph_neighbors.discard(pivot_ind)
+ pivot_local_graph = list(local_graph_neighbors)
+ pivot_local_graph.insert(0, pivot_ind)
+ pivot_knn = [pivot_ind] + list(knn)
+
+ pivot_local_graphs.append(pivot_local_graph)
+ pivot_knns.append(pivot_knn)
+
+ num_max_nodes = max([
+ len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
+ ])
+
+ local_graphs_node_feat = []
+ adjacent_matrices = []
+ pivots_knn_inds = []
+ pivots_local_graphs = []
+
+ for graph_ind, pivot_knn in enumerate(pivot_knns):
+ pivot_local_graph = pivot_local_graphs[graph_ind]
+ num_nodes = len(pivot_local_graph)
+ pivot_ind = pivot_local_graph[0]
+ node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
+
+ knn_inds = paddle.cast(
+ paddle.to_tensor([node2ind_map[i]
+ for i in pivot_knn[1:]]), 'int64')
+ pivot_feats = node_feats[pivot_ind]
+ normalized_feats = node_feats[paddle.to_tensor(
+ pivot_local_graph)] - pivot_feats
+
+ adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
+ for node in pivot_local_graph:
+ neighbors = sorted_dist_inds[node, 1:self.active_connection + 1]
+ for neighbor in neighbors:
+ if neighbor in pivot_local_graph:
+ adjacent_matrix[node2ind_map[node], node2ind_map[
+ neighbor]] = 1
+ adjacent_matrix[node2ind_map[neighbor], node2ind_map[
+ node]] = 1
+
+ adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
+ pad_adjacent_matrix = paddle.zeros((num_max_nodes, num_max_nodes), )
+ pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
+ paddle.to_tensor(adjacent_matrix), 'float32')
+
+ pad_normalized_feats = paddle.concat(
+ [
+ normalized_feats, paddle.zeros(
+ (num_max_nodes - num_nodes, normalized_feats.shape[1]),
+ )
+ ],
+ axis=0)
+
+ local_graph_nodes = paddle.to_tensor(pivot_local_graph)
+ local_graph_nodes = paddle.concat(
+ [
+ local_graph_nodes, paddle.zeros(
+ [num_max_nodes - num_nodes], dtype='int64')
+ ],
+ axis=-1)
+
+ local_graphs_node_feat.append(pad_normalized_feats)
+ adjacent_matrices.append(pad_adjacent_matrix)
+ pivots_knn_inds.append(knn_inds)
+ pivots_local_graphs.append(local_graph_nodes)
+
+ local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
+ adjacent_matrices = paddle.stack(adjacent_matrices, 0)
+ pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
+ pivots_local_graphs = paddle.stack(pivots_local_graphs, 0)
+
+ return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
+ pivots_local_graphs)
+
+ def __call__(self, preds, feat_maps):
+ """Generate local graphs and graph convolutional network input data.
+
+ Args:
+ preds (tensor): The predicted maps.
+ feat_maps (tensor): The feature maps to extract content feature of
+ text components.
+
+ Returns:
+ none_flag (bool): The flag showing whether the number of proposed
+ text components is 0.
+ local_graphs_node_feats (tensor): The features of nodes in local
+ graphs.
+ adjacent_matrices (tensor): The adjacent matrices.
+ pivots_knn_inds (tensor): The k-nearest neighbor indices in
+ local graphs.
+ pivots_local_graphs (tensor): The indices of nodes in local
+ graphs.
+ text_comps (ndarray): The predicted text components.
+ """
+ if preds.ndim == 4:
+ assert preds.shape[0] == 1
+ preds = paddle.squeeze(preds)
+ pred_text_region = F.sigmoid(preds[0]).numpy()
+ pred_center_region = F.sigmoid(preds[1]).numpy()
+ pred_sin_map = preds[2].numpy()
+ pred_cos_map = preds[3].numpy()
+ pred_top_height_map = preds[4].numpy()
+ pred_bot_height_map = preds[5].numpy()
+
+ comp_attribs, text_comps = self.propose_comps_and_attribs(
+ pred_text_region, pred_center_region, pred_top_height_map,
+ pred_bot_height_map, pred_sin_map, pred_cos_map)
+
+ if comp_attribs is None or len(comp_attribs) < 2:
+ none_flag = True
+ return none_flag, (0, 0, 0, 0, 0)
+
+ comp_centers = comp_attribs[:, 0:2]
+ distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
+
+ geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
+ geo_feats = paddle.to_tensor(geo_feats)
+
+ batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
+ comp_attribs = comp_attribs.astype(np.float32)
+ angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
+ angle = angle.reshape((-1, 1))
+ rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
+ rois = paddle.to_tensor(rotated_rois)
+
+ content_feats = self.pooling(feat_maps, rois)
+ content_feats = content_feats.reshape([content_feats.shape[0], -1])
+ node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
+
+ sorted_dist_inds = np.argsort(distance_matrix, axis=1)
+ (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
+ pivots_local_graphs) = self.generate_local_graphs(sorted_dist_inds,
+ node_feats)
+
+ none_flag = False
+ return none_flag, (local_graphs_node_feat, adjacent_matrices,
+ pivots_knn_inds, pivots_local_graphs, text_comps)
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index a94d223a1e67999cbd4363d83771d141a33c668d..f5e89a5b80f665d77833ffedaa2c141a3022f25d 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -27,11 +27,12 @@ def build_neck(config):
from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN
from .ct_fpn import CTFPN
+ from .fpn_unet import FPN_UNet
from .rf_adaptor import RFAdaptor
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
- 'RFAdaptor'
+ 'RFAdaptor', 'FPN_UNet'
]
module_name = config.pop('name')
diff --git a/ppocr/modeling/necks/fpn_unet.py b/ppocr/modeling/necks/fpn_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e94a8b50532cfbbfea1cecdba6cfb0d5a239cd
--- /dev/null
+++ b/ppocr/modeling/necks/fpn_unet.py
@@ -0,0 +1,97 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/necks/fpn_unet.py
+"""
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class UpBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ assert isinstance(in_channels, int)
+ assert isinstance(out_channels, int)
+
+ self.conv1x1 = nn.Conv2D(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.conv3x3 = nn.Conv2D(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.deconv = nn.Conv2DTranspose(
+ out_channels, out_channels, kernel_size=4, stride=2, padding=1)
+
+ def forward(self, x):
+ x = F.relu(self.conv1x1(x))
+ x = F.relu(self.conv3x3(x))
+ x = self.deconv(x)
+ return x
+
+
+class FPN_UNet(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ assert len(in_channels) == 4
+ assert isinstance(out_channels, int)
+ self.out_channels = out_channels
+
+ blocks_out_channels = [out_channels] + [
+ min(out_channels * 2**i, 256) for i in range(4)
+ ]
+ blocks_in_channels = [blocks_out_channels[1]] + [
+ in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
+ ] + [in_channels[3]]
+
+ self.up4 = nn.Conv2DTranspose(
+ blocks_in_channels[4],
+ blocks_out_channels[4],
+ kernel_size=4,
+ stride=2,
+ padding=1)
+ self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
+ self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
+ self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
+ self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])
+
+ def forward(self, x):
+ """
+ Args:
+ x (list[Tensor] | tuple[Tensor]): A list of four tensors of shape
+ :math:`(N, C_i, H_i, W_i)`, representing C2, C3, C4, C5
+ features respectively. :math:`C_i` should matches the number in
+ ``in_channels``.
+
+ Returns:
+ Tensor: Shape :math:`(N, C, H, W)` where :math:`H=4H_0` and
+ :math:`W=4W_0`.
+ """
+ c2, c3, c4, c5 = x
+
+ x = F.relu(self.up4(c5))
+
+ x = paddle.concat([x, c4], axis=1)
+ x = F.relu(self.up_block3(x))
+
+ x = paddle.concat([x, c3], axis=1)
+ x = F.relu(self.up_block2(x))
+
+ x = paddle.concat([x, c2], axis=1)
+ x = F.relu(self.up_block1(x))
+
+ x = self.up_block0(x)
+ return x
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index b5715967b01aefb13e8b1cc9654924483376b72b..3a09030b25461029d9160699dc591eaedab9e0db 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -36,6 +36,7 @@ from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, Di
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
+from .drrg_postprocess import DRRGPostprocess
def build_post_process(config, global_config=None):
@@ -49,7 +50,8 @@ def build_post_process(config, global_config=None):
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
- 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
+ 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
+ 'RFLLabelDecode', 'DRRGPostprocess'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/drrg_postprocess.py b/ppocr/postprocess/drrg_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..353081c9d4d0fa1d04d995c84445445767276cc8
--- /dev/null
+++ b/ppocr/postprocess/drrg_postprocess.py
@@ -0,0 +1,326 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/postprocess/drrg_postprocessor.py
+"""
+
+import functools
+import operator
+
+import numpy as np
+import paddle
+from numpy.linalg import norm
+import cv2
+
+
+class Node:
+ def __init__(self, ind):
+ self.__ind = ind
+ self.__links = set()
+
+ @property
+ def ind(self):
+ return self.__ind
+
+ @property
+ def links(self):
+ return set(self.__links)
+
+ def add_link(self, link_node):
+ self.__links.add(link_node)
+ link_node.__links.add(self)
+
+
+def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
+ assert edges.ndim == 2
+ assert edges.shape[1] == 2
+ assert edges.shape[0] == scores.shape[0]
+ assert text_comps.ndim == 2
+ assert isinstance(edge_len_thr, float)
+
+ edges = np.sort(edges, axis=1)
+ score_dict = {}
+ for i, edge in enumerate(edges):
+ if text_comps is not None:
+ box1 = text_comps[edge[0], :8].reshape(4, 2)
+ box2 = text_comps[edge[1], :8].reshape(4, 2)
+ center1 = np.mean(box1, axis=0)
+ center2 = np.mean(box2, axis=0)
+ distance = norm(center1 - center2)
+ if distance > edge_len_thr:
+ scores[i] = 0
+ if (edge[0], edge[1]) in score_dict:
+ score_dict[edge[0], edge[1]] = 0.5 * (
+ score_dict[edge[0], edge[1]] + scores[i])
+ else:
+ score_dict[edge[0], edge[1]] = scores[i]
+
+ nodes = np.sort(np.unique(edges.flatten()))
+ mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int)
+ mapping[nodes] = np.arange(nodes.shape[0])
+ order_inds = mapping[edges]
+ vertices = [Node(node) for node in nodes]
+ for ind in order_inds:
+ vertices[ind[0]].add_link(vertices[ind[1]])
+
+ return vertices, score_dict
+
+
+def connected_components(nodes, score_dict, link_thr):
+ assert isinstance(nodes, list)
+ assert all([isinstance(node, Node) for node in nodes])
+ assert isinstance(score_dict, dict)
+ assert isinstance(link_thr, float)
+
+ clusters = []
+ nodes = set(nodes)
+ while nodes:
+ node = nodes.pop()
+ cluster = {node}
+ node_queue = [node]
+ while node_queue:
+ node = node_queue.pop(0)
+ neighbors = set([
+ neighbor for neighbor in node.links
+ if score_dict[tuple(sorted([node.ind, neighbor.ind]))] >=
+ link_thr
+ ])
+ neighbors.difference_update(cluster)
+ nodes.difference_update(neighbors)
+ cluster.update(neighbors)
+ node_queue.extend(neighbors)
+ clusters.append(list(cluster))
+ return clusters
+
+
+def clusters2labels(clusters, num_nodes):
+ assert isinstance(clusters, list)
+ assert all([isinstance(cluster, list) for cluster in clusters])
+ assert all(
+ [isinstance(node, Node) for cluster in clusters for node in cluster])
+ assert isinstance(num_nodes, int)
+
+ node_labels = np.zeros(num_nodes)
+ for cluster_ind, cluster in enumerate(clusters):
+ for node in cluster:
+ node_labels[node.ind] = cluster_ind
+ return node_labels
+
+
+def remove_single(text_comps, comp_pred_labels):
+ assert text_comps.ndim == 2
+ assert text_comps.shape[0] == comp_pred_labels.shape[0]
+
+ single_flags = np.zeros_like(comp_pred_labels)
+ pred_labels = np.unique(comp_pred_labels)
+ for label in pred_labels:
+ current_label_flag = (comp_pred_labels == label)
+ if np.sum(current_label_flag) == 1:
+ single_flags[np.where(current_label_flag)[0][0]] = 1
+ keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
+ filtered_text_comps = text_comps[keep_ind, :]
+ filtered_labels = comp_pred_labels[keep_ind]
+
+ return filtered_text_comps, filtered_labels
+
+
+def norm2(point1, point2):
+ return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
+
+
+def min_connect_path(points):
+ assert isinstance(points, list)
+ assert all([isinstance(point, list) for point in points])
+ assert all([isinstance(coord, int) for point in points for coord in point])
+
+ points_queue = points.copy()
+ shortest_path = []
+ current_edge = [[], []]
+
+ edge_dict0 = {}
+ edge_dict1 = {}
+ current_edge[0] = points_queue[0]
+ current_edge[1] = points_queue[0]
+ points_queue.remove(points_queue[0])
+ while points_queue:
+ for point in points_queue:
+ length0 = norm2(point, current_edge[0])
+ edge_dict0[length0] = [point, current_edge[0]]
+ length1 = norm2(current_edge[1], point)
+ edge_dict1[length1] = [current_edge[1], point]
+ key0 = min(edge_dict0.keys())
+ key1 = min(edge_dict1.keys())
+
+ if key0 <= key1:
+ start = edge_dict0[key0][0]
+ end = edge_dict0[key0][1]
+ shortest_path.insert(0, [points.index(start), points.index(end)])
+ points_queue.remove(start)
+ current_edge[0] = start
+ else:
+ start = edge_dict1[key1][0]
+ end = edge_dict1[key1][1]
+ shortest_path.append([points.index(start), points.index(end)])
+ points_queue.remove(end)
+ current_edge[1] = end
+
+ edge_dict0 = {}
+ edge_dict1 = {}
+
+ shortest_path = functools.reduce(operator.concat, shortest_path)
+ shortest_path = sorted(set(shortest_path), key=shortest_path.index)
+
+ return shortest_path
+
+
+def in_contour(cont, point):
+ x, y = point
+ is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
+ return is_inner
+
+
+def fix_corner(top_line, bot_line, start_box, end_box):
+ assert isinstance(top_line, list)
+ assert all(isinstance(point, list) for point in top_line)
+ assert isinstance(bot_line, list)
+ assert all(isinstance(point, list) for point in bot_line)
+ assert start_box.shape == end_box.shape == (4, 2)
+
+ contour = np.array(top_line + bot_line[::-1])
+ start_left_mid = (start_box[0] + start_box[3]) / 2
+ start_right_mid = (start_box[1] + start_box[2]) / 2
+ end_left_mid = (end_box[0] + end_box[3]) / 2
+ end_right_mid = (end_box[1] + end_box[2]) / 2
+ if not in_contour(contour, start_left_mid):
+ top_line.insert(0, start_box[0].tolist())
+ bot_line.insert(0, start_box[3].tolist())
+ elif not in_contour(contour, start_right_mid):
+ top_line.insert(0, start_box[1].tolist())
+ bot_line.insert(0, start_box[2].tolist())
+ if not in_contour(contour, end_left_mid):
+ top_line.append(end_box[0].tolist())
+ bot_line.append(end_box[3].tolist())
+ elif not in_contour(contour, end_right_mid):
+ top_line.append(end_box[1].tolist())
+ bot_line.append(end_box[2].tolist())
+ return top_line, bot_line
+
+
+def comps2boundaries(text_comps, comp_pred_labels):
+ assert text_comps.ndim == 2
+ assert len(text_comps) == len(comp_pred_labels)
+ boundaries = []
+ if len(text_comps) < 1:
+ return boundaries
+ for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
+ cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
+ text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
+ (-1, 4, 2)).astype(np.int32)
+ score = np.mean(text_comps[cluster_comp_inds, -1])
+
+ if text_comp_boxes.shape[0] < 1:
+ continue
+
+ elif text_comp_boxes.shape[0] > 1:
+ centers = np.mean(text_comp_boxes, axis=1).astype(np.int32).tolist()
+ shortest_path = min_connect_path(centers)
+ text_comp_boxes = text_comp_boxes[shortest_path]
+ top_line = np.mean(
+ text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
+ bot_line = np.mean(
+ text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
+ top_line, bot_line = fix_corner(
+ top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1])
+ boundary_points = top_line + bot_line[::-1]
+
+ else:
+ top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
+ bot_line = text_comp_boxes[0, 2:4:-1, :].astype(np.int32).tolist()
+ boundary_points = top_line + bot_line
+
+ boundary = [p for coord in boundary_points for p in coord] + [score]
+ boundaries.append(boundary)
+
+ return boundaries
+
+
+class DRRGPostprocess(object):
+ """Merge text components and construct boundaries of text instances.
+
+ Args:
+ link_thr (float): The edge score threshold.
+ """
+
+ def __init__(self, link_thr, **kwargs):
+ assert isinstance(link_thr, float)
+ self.link_thr = link_thr
+
+ def __call__(self, preds, shape_list):
+ """
+ Args:
+ edges (ndarray): The edge array of shape N * 2, each row is a node
+ index pair that makes up an edge in graph.
+ scores (ndarray): The edge score array of shape (N,).
+ text_comps (ndarray): The text components.
+
+ Returns:
+ List[list[float]]: The predicted boundaries of text instances.
+ """
+ edges, scores, text_comps = preds
+ if edges is not None:
+ if isinstance(edges, paddle.Tensor):
+ edges = edges.numpy()
+ if isinstance(scores, paddle.Tensor):
+ scores = scores.numpy()
+ if isinstance(text_comps, paddle.Tensor):
+ text_comps = text_comps.numpy()
+ assert len(edges) == len(scores)
+ assert text_comps.ndim == 2
+ assert text_comps.shape[1] == 9
+
+ vertices, score_dict = graph_propagation(edges, scores, text_comps)
+ clusters = connected_components(vertices, score_dict, self.link_thr)
+ pred_labels = clusters2labels(clusters, text_comps.shape[0])
+ text_comps, pred_labels = remove_single(text_comps, pred_labels)
+ boundaries = comps2boundaries(text_comps, pred_labels)
+ else:
+ boundaries = []
+
+ boundaries, scores = self.resize_boundary(
+ boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
+ boxes_batch = [dict(points=boundaries, scores=scores)]
+ return boxes_batch
+
+ def resize_boundary(self, boundaries, scale_factor):
+ """Rescale boundaries via scale_factor.
+
+ Args:
+ boundaries (list[list[float]]): The boundary list. Each boundary
+ with size 2k+1 with k>=4.
+ scale_factor(ndarray): The scale factor of size (4,).
+
+ Returns:
+ boundaries (list[list[float]]): The scaled boundaries.
+ """
+ boxes = []
+ scores = []
+ for b in boundaries:
+ sz = len(b)
+ scores.append(b[-1])
+ b = (np.array(b[:sz - 1]) *
+ (np.tile(scale_factor[:2], int(
+ (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
+ boxes.append(np.array(b).reshape([-1, 2]))
+ return boxes, scores
diff --git a/requirements.txt b/requirements.txt
index 7a018b50952a876b4839eabbd72fac09d2bbd73b..d795e06f0f76ee7ae009772ae8ff2bdbc321a16a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,5 @@ premailer
openpyxl
attrdict
Polygon3
+lanms-neo==1.0.2
PyMuPDF==1.18.7
diff --git a/tools/program.py b/tools/program.py
index 8fcf68866d0a7a4ec4848d3b61bf4519130d7cdd..5d2bd5bfb034940e3bec802b5e7041c8e82a9271 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -220,7 +220,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
- "RobustScanner", "RFL"
+ "RobustScanner", "RFL", 'DRRG'
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
@@ -629,7 +629,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
- 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
+ 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG'
]
if use_xpu: