diff --git a/configs/det/det_r18_vd_ct.yml b/configs/det/det_r18_vd_ct.yml
new file mode 100644
index 0000000000000000000000000000000000000000..42922dfd22c0e49d20d50534c76fedae16b27a4a
--- /dev/null
+++ b/configs/det/det_r18_vd_ct.yml
@@ -0,0 +1,107 @@
+Global:
+ use_gpu: true
+ epoch_num: 600
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/det_ct/
+ save_epoch_step: 10
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0,1000]
+ cal_metric_during_train: False
+ pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_en/img623.jpg
+ save_res_path: ./output/det_ct/predicts_ct.txt
+
+Architecture:
+ model_type: det
+ algorithm: CT
+ Transform:
+ Backbone:
+ name: ResNet_vd
+ layers: 18
+ Neck:
+ name: CTFPN
+ Head:
+ name: CT_Head
+ in_channels: 512
+ hidden_dim: 128
+ num_classes: 3
+
+Loss:
+ name: CTLoss
+
+Optimizer:
+ name: Adam
+ lr: #PolynomialDecay
+ name: Linear
+ learning_rate: 0.001
+ end_lr: 0.
+ epochs: 600
+ step_each_epoch: 1254
+ power: 0.9
+
+PostProcess:
+ name: CTPostProcess
+ box_type: poly
+
+Metric:
+ name: CTMetric
+ main_indicator: f_score
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/total_text/train
+ label_file_list:
+ - ./train_data/total_text/train/train.txt
+ ratio_list: [1.0]
+ transforms:
+ - DecodeImage:
+ img_mode: RGB
+ channel_first: False
+ - CTLabelEncode: # Class handling label
+ - RandomScale:
+ - MakeShrink:
+ - GroupRandomHorizontalFlip:
+ - GroupRandomRotate:
+ - GroupRandomCropPadding:
+ - MakeCentripetalShift:
+ - ColorJitter:
+ brightness: 0.125
+ saturation: 0.5
+ - ToCHWImage:
+ - NormalizeImage:
+ - KeepKeys:
+ keep_keys: ['image', 'gt_kernel', 'training_mask', 'gt_instance', 'gt_kernel_instance', 'training_mask_distance', 'gt_distance'] # the order of the dataloader list
+ loader:
+ shuffle: True
+ drop_last: True
+ batch_size_per_card: 4
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/total_text/test
+ label_file_list:
+ - ./train_data/total_text/test/test.txt
+ ratio_list: [1.0]
+ transforms:
+ - DecodeImage:
+ img_mode: RGB
+ channel_first: False
+ - CTLabelEncode: # Class handling label
+ - ScaleAlignedShort:
+ - NormalizeImage:
+ order: 'hwc'
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: ['image', 'shape', 'polys', 'texts'] # the order of the dataloader list
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 2
diff --git a/doc/doc_ch/algorithm_det_ct.md b/doc/doc_ch/algorithm_det_ct.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea3522b7bf3c2dc17ef4f645bc47738477f07cf1
--- /dev/null
+++ b/doc/doc_ch/algorithm_det_ct.md
@@ -0,0 +1,95 @@
+# CT
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945)
+> Tao Sheng, Jie Chen, Zhouhui Lian
+> NeurIPS, 2021
+
+
+在Total-Text文本检测公开数据集上,算法复现效果如下:
+
+|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
+| --- | --- | --- | --- | --- | --- | --- |
+|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)|
+
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+CT模型使用Total-Text文本检测公开数据集训练得到,数据集下载可参考 [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset), 我们将标签文件转成了paddleocr格式,转换好的标签文件下载参考[train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [text.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt)。
+
+请参考[文本检测训练教程](./detection.md)。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。
+
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将CT文本检测训练过程中保存的模型,转换成inference model。以基于Resnet18_vd骨干网络,在Total-Text英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar) ),可以使用如下命令进行转换:
+
+```shell
+python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct
+```
+
+CT文本检测模型推理,可以执行如下命令:
+
+```shell
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT"
+```
+
+可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
+
+![](../imgs_results/det_res_img623_ct.jpg)
+
+
+
+### 4.2 C++推理
+
+暂不支持
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@inproceedings{sheng2021centripetaltext,
+ title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection},
+ author={Tao Sheng and Jie Chen and Zhouhui Lian},
+ booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
+ year={2021}
+}
+```
diff --git a/doc/doc_en/algorithm_det_ct_en.md b/doc/doc_en/algorithm_det_ct_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..d56b3fc6b3353bacb1f26fba3873ba5276b10c8b
--- /dev/null
+++ b/doc/doc_en/algorithm_det_ct_en.md
@@ -0,0 +1,96 @@
+# CT
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945)
+> Tao Sheng, Jie Chen, Zhouhui Lian
+> NeurIPS, 2021
+
+
+On the Total-Text dataset, the text detection result is as follows:
+
+|Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
+| --- | --- | --- | --- | --- | --- | --- |
+|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)|
+
+
+
+## 2. Environment
+Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md).
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+
+The above CT model is trained using the Total-Text text detection public dataset. For the download of the dataset, please refer to [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset). PaddleOCR format annotation download link [train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [test.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt).
+
+
+Please refer to [text detection training tutorial](./detection_en.md). PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different detection models.
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the model saved in the CT text detection training process into an inference model. Taking the model based on the Resnet18_vd backbone network and trained on the Total Text English dataset as example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)), you can use the following command to convert:
+
+```shell
+python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct
+```
+
+CT text detection model inference, you can execute the following command:
+
+```shell
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT"
+```
+
+The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
+
+![](../imgs_results/det_res_img623_ct.jpg)
+
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@inproceedings{sheng2021centripetaltext,
+ title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection},
+ author={Tao Sheng and Jie Chen and Zhouhui Lian},
+ booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
+ year={2021}
+}
+```
diff --git a/doc/imgs_results/det_res_img623_ct.jpg b/doc/imgs_results/det_res_img623_ct.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2c5f57d96cca896c70d9e0d33ba80a0177a8aeb9
Binary files /dev/null and b/doc/imgs_results/det_res_img623_ct.jpg differ
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 102f48fcc19e59d9f8ffb0ad496f54cc64864f7d..863988cccfa9d9f2c865a444410d4245687f49ee 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -43,6 +43,7 @@ from .vqa import *
from .fce_aug import *
from .fce_targets import FCENetTargets
+from .ct_process import *
def transform(data, ops=None):
diff --git a/ppocr/data/imaug/ct_process.py b/ppocr/data/imaug/ct_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..59715090036e1020800950b02b9ea06ab5c8d4c2
--- /dev/null
+++ b/ppocr/data/imaug/ct_process.py
@@ -0,0 +1,355 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import random
+import pyclipper
+import paddle
+
+import numpy as np
+import Polygon as plg
+import scipy.io as scio
+
+from PIL import Image
+import paddle.vision.transforms as transforms
+
+
+class RandomScale():
+ def __init__(self, short_size=640, **kwargs):
+ self.short_size = short_size
+
+ def scale_aligned(self, img, scale):
+ oh, ow = img.shape[0:2]
+ h = int(oh * scale + 0.5)
+ w = int(ow * scale + 0.5)
+ if h % 32 != 0:
+ h = h + (32 - h % 32)
+ if w % 32 != 0:
+ w = w + (32 - w % 32)
+ img = cv2.resize(img, dsize=(w, h))
+ factor_h = h / oh
+ factor_w = w / ow
+ return img, factor_h, factor_w
+
+ def __call__(self, data):
+ img = data['image']
+
+ h, w = img.shape[0:2]
+ random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])
+ scale = (np.random.choice(random_scale) * self.short_size) / min(h, w)
+ img, factor_h, factor_w = self.scale_aligned(img, scale)
+
+ data['scale_factor'] = (factor_w, factor_h)
+ data['image'] = img
+ return data
+
+
+class MakeShrink():
+ def __init__(self, kernel_scale=0.7, **kwargs):
+ self.kernel_scale = kernel_scale
+
+ def dist(self, a, b):
+ return np.linalg.norm((a - b), ord=2, axis=0)
+
+ def perimeter(self, bbox):
+ peri = 0.0
+ for i in range(bbox.shape[0]):
+ peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
+ return peri
+
+ def shrink(self, bboxes, rate, max_shr=20):
+ rate = rate * rate
+ shrinked_bboxes = []
+ for bbox in bboxes:
+ area = plg.Polygon(bbox).area()
+ peri = self.perimeter(bbox)
+
+ try:
+ pco = pyclipper.PyclipperOffset()
+ pco.AddPath(bbox, pyclipper.JT_ROUND,
+ pyclipper.ET_CLOSEDPOLYGON)
+ offset = min(
+ int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
+
+ shrinked_bbox = pco.Execute(-offset)
+ if len(shrinked_bbox) == 0:
+ shrinked_bboxes.append(bbox)
+ continue
+
+ shrinked_bbox = np.array(shrinked_bbox[0])
+ if shrinked_bbox.shape[0] <= 2:
+ shrinked_bboxes.append(bbox)
+ continue
+
+ shrinked_bboxes.append(shrinked_bbox)
+ except Exception as e:
+ shrinked_bboxes.append(bbox)
+
+ return shrinked_bboxes
+
+ def __call__(self, data):
+ img = data['image']
+ bboxes = data['polys']
+ words = data['texts']
+ scale_factor = data['scale_factor']
+
+ gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w
+ training_mask = np.ones(img.shape[0:2], dtype='uint8')
+ training_mask_distance = np.ones(img.shape[0:2], dtype='uint8')
+
+ for i in range(len(bboxes)):
+ bboxes[i] = np.reshape(bboxes[i] * (
+ [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
+ (bboxes[i].shape[0] // 2, 2)).astype('int32')
+
+ for i in range(len(bboxes)):
+ #different value for different bbox
+ cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
+
+ # set training mask to 0
+ cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)
+
+ # for not accurate annotation, use training_mask_distance
+ if words[i] == '###' or words[i] == '???':
+ cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1)
+
+ # make shrink
+ gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8')
+ kernel_bboxes = self.shrink(bboxes, self.kernel_scale)
+ for i in range(len(bboxes)):
+ cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1,
+ -1)
+
+ # for training mask, kernel and background= 1, box region=0
+ if words[i] != '###' and words[i] != '???':
+ cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1)
+
+ gt_kernel = gt_kernel_instance.copy()
+ # for gt_kernel, kernel = 1
+ gt_kernel[gt_kernel > 0] = 1
+
+ # shrink 2 times
+ tmp1 = gt_kernel_instance.copy()
+ erode_kernel = np.ones((3, 3), np.uint8)
+ tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1)
+ tmp2 = tmp1.copy()
+ tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1)
+
+ # compute text region
+ gt_kernel_inner = tmp1 - tmp2
+
+ # gt_instance: text instance, bg=0, diff word use diff value
+ # training_mask: text instance mask, word=0,kernel and bg=1
+ # gt_kernel_instance: text kernel instance, bg=0, diff word use diff value
+ # gt_kernel: text_kernel, bg=0,diff word use same value
+ # gt_kernel_inner: text kernel reference
+ # training_mask_distance: word without anno = 0, else 1
+
+ data['image'] = [
+ img, gt_instance, training_mask, gt_kernel_instance, gt_kernel,
+ gt_kernel_inner, training_mask_distance
+ ]
+ return data
+
+
+class GroupRandomHorizontalFlip():
+ def __init__(self, p=0.5, **kwargs):
+ self.p = p
+
+ def __call__(self, data):
+ imgs = data['image']
+
+ if random.random() < self.p:
+ for i in range(len(imgs)):
+ imgs[i] = np.flip(imgs[i], axis=1).copy()
+ data['image'] = imgs
+ return data
+
+
+class GroupRandomRotate():
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ imgs = data['image']
+
+ max_angle = 10
+ angle = random.random() * 2 * max_angle - max_angle
+ for i in range(len(imgs)):
+ img = imgs[i]
+ w, h = img.shape[:2]
+ rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
+ img_rotation = cv2.warpAffine(
+ img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST)
+ imgs[i] = img_rotation
+
+ data['image'] = imgs
+ return data
+
+
+class GroupRandomCropPadding():
+ def __init__(self, target_size=(640, 640), **kwargs):
+ self.target_size = target_size
+
+ def __call__(self, data):
+ imgs = data['image']
+
+ h, w = imgs[0].shape[0:2]
+ t_w, t_h = self.target_size
+ p_w, p_h = self.target_size
+ if w == t_w and h == t_h:
+ return data
+
+ t_h = t_h if t_h < h else h
+ t_w = t_w if t_w < w else w
+
+ if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
+ # make sure to crop the text region
+ tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
+ tl[tl < 0] = 0
+ br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
+ br[br < 0] = 0
+ br[0] = min(br[0], h - t_h)
+ br[1] = min(br[1], w - t_w)
+
+ i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+ j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
+ else:
+ i = random.randint(0, h - t_h) if h - t_h > 0 else 0
+ j = random.randint(0, w - t_w) if w - t_w > 0 else 0
+
+ n_imgs = []
+ for idx in range(len(imgs)):
+ if len(imgs[idx].shape) == 3:
+ s3_length = int(imgs[idx].shape[-1])
+ img = imgs[idx][i:i + t_h, j:j + t_w, :]
+ img_p = cv2.copyMakeBorder(
+ img,
+ 0,
+ p_h - t_h,
+ 0,
+ p_w - t_w,
+ borderType=cv2.BORDER_CONSTANT,
+ value=tuple(0 for i in range(s3_length)))
+ else:
+ img = imgs[idx][i:i + t_h, j:j + t_w]
+ img_p = cv2.copyMakeBorder(
+ img,
+ 0,
+ p_h - t_h,
+ 0,
+ p_w - t_w,
+ borderType=cv2.BORDER_CONSTANT,
+ value=(0, ))
+ n_imgs.append(img_p)
+
+ data['image'] = n_imgs
+ return data
+
+
+class MakeCentripetalShift():
+ def __init__(self, **kwargs):
+ pass
+
+ def jaccard(self, As, Bs):
+ A = As.shape[0] # small
+ B = Bs.shape[0] # large
+
+ dis = np.sqrt(
+ np.sum((As[:, np.newaxis, :].repeat(
+ B, axis=1) - Bs[np.newaxis, :, :].repeat(
+ A, axis=0))**2,
+ axis=-1))
+
+ ind = np.argmin(dis, axis=-1)
+
+ return ind
+
+ def __call__(self, data):
+ imgs = data['image']
+
+ img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \
+ imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6]
+
+ max_instance = np.max(gt_instance) # num bbox
+
+ # make centripetal shift
+ gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32)
+ for i in range(1, max_instance + 1):
+ # kernel_reference
+ ind = (gt_kernel_inner == i)
+
+ if np.sum(ind) == 0:
+ training_mask[gt_instance == i] = 0
+ training_mask_distance[gt_instance == i] = 0
+ continue
+
+ kpoints = np.array(np.where(ind)).transpose(
+ (1, 0))[:, ::-1].astype('float32')
+
+ ind = (gt_instance == i) * (gt_kernel_instance == 0)
+ if np.sum(ind) == 0:
+ continue
+ pixels = np.where(ind)
+
+ points = np.array(pixels).transpose(
+ (1, 0))[:, ::-1].astype('float32')
+
+ bbox_ind = self.jaccard(points, kpoints)
+
+ offset_gt = kpoints[bbox_ind] - points
+
+ gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1
+
+ img = Image.fromarray(img)
+ img = img.convert('RGB')
+
+ data["image"] = img
+ data["gt_kernel"] = gt_kernel.astype("int64")
+ data["training_mask"] = training_mask.astype("int64")
+ data["gt_instance"] = gt_instance.astype("int64")
+ data["gt_kernel_instance"] = gt_kernel_instance.astype("int64")
+ data["training_mask_distance"] = training_mask_distance.astype("int64")
+ data["gt_distance"] = gt_distance.astype("float32")
+
+ return data
+
+
+class ScaleAlignedShort():
+ def __init__(self, short_size=640, **kwargs):
+ self.short_size = short_size
+
+ def __call__(self, data):
+ img = data['image']
+
+ org_img_shape = img.shape
+
+ h, w = img.shape[0:2]
+ scale = self.short_size * 1.0 / min(h, w)
+ h = int(h * scale + 0.5)
+ w = int(w * scale + 0.5)
+ if h % 32 != 0:
+ h = h + (32 - h % 32)
+ if w % 32 != 0:
+ w = w + (32 - w % 32)
+ img = cv2.resize(img, dsize=(w, h))
+
+ new_img_shape = img.shape
+ img_shape = np.array(org_img_shape + new_img_shape)
+
+ data['shape'] = img_shape
+ data['image'] = img
+
+ return data
\ No newline at end of file
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 59cb9b8a253cf04244ebf83511ab412174487a53..dbfb93176cc782bedc8f7b33367b59046c4abec8 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1395,3 +1395,29 @@ class VLLabelEncode(BaseRecLabelEncode):
data['label_res'] = np.array(label_res)
data['label_sub'] = np.array(label_sub)
return data
+
+
+class CTLabelEncode(object):
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ label = data['label']
+
+ label = json.loads(label)
+ nBox = len(label)
+ boxes, txts = [], []
+ for bno in range(0, nBox):
+ box = label[bno]['points']
+ box = np.array(box)
+
+ boxes.append(box)
+ txt = label[bno]['transcription']
+ txts.append(txt)
+
+ if len(boxes) == 0:
+ return None
+
+ data['polys'] = boxes
+ data['texts'] = txts
+ return data
\ No newline at end of file
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 1a11778945c9d7b5f5519cd55473e8bf7790db2c..02525b3d50ad87509a6cba6fb2c1b00cb0add56e 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -25,6 +25,7 @@ from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss
+from .det_ct_loss import CTLoss
# rec loss
from .rec_ctc_loss import CTCLoss
@@ -68,7 +69,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
- 'SLALoss'
+ 'SLALoss', 'CTLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/det_ct_loss.py b/ppocr/losses/det_ct_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..f48c95be4f84e2d8520363379b3061fa4245c105
--- /dev/null
+++ b/ppocr/losses/det_ct_loss.py
@@ -0,0 +1,276 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/shengtao96/CentripetalText/tree/main/models/loss
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+def ohem_single(score, gt_text, training_mask):
+ # online hard example mining
+
+ pos_num = int(paddle.sum(gt_text > 0.5)) - int(
+ paddle.sum((gt_text > 0.5) & (training_mask <= 0.5)))
+
+ if pos_num == 0:
+ # selected_mask = gt_text.copy() * 0 # may be not good
+ selected_mask = training_mask
+ selected_mask = paddle.cast(
+ selected_mask.reshape(
+ (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+ return selected_mask
+
+ neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5)))
+ neg_num = int(min(pos_num * 3, neg_num))
+
+ if neg_num == 0:
+ selected_mask = training_mask
+ selected_mask = paddle.cast(
+ selected_mask.reshape(
+ (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+ return selected_mask
+
+ # hard example
+ neg_score = score[(gt_text <= 0.5) & (training_mask > 0.5)]
+ neg_score_sorted = paddle.sort(-neg_score)
+ threshold = -neg_score_sorted[neg_num - 1]
+
+ selected_mask = ((score >= threshold) |
+ (gt_text > 0.5)) & (training_mask > 0.5)
+ selected_mask = paddle.cast(
+ selected_mask.reshape(
+ (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+ return selected_mask
+
+
+def ohem_batch(scores, gt_texts, training_masks):
+ selected_masks = []
+ for i in range(scores.shape[0]):
+ selected_masks.append(
+ ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
+ i, :, :]))
+
+ selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32")
+ return selected_masks
+
+
+def iou_single(a, b, mask, n_class):
+ EPS = 1e-6
+ valid = mask == 1
+ a = a[valid]
+ b = b[valid]
+ miou = []
+
+ # iou of each class
+ for i in range(n_class):
+ inter = paddle.cast(((a == i) & (b == i)), "float32")
+ union = paddle.cast(((a == i) | (b == i)), "float32")
+
+ miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
+ miou = sum(miou) / len(miou)
+ return miou
+
+
+def iou(a, b, mask, n_class=2, reduce=True):
+ batch_size = a.shape[0]
+
+ a = a.reshape((batch_size, -1))
+ b = b.reshape((batch_size, -1))
+ mask = mask.reshape((batch_size, -1))
+
+ iou = paddle.zeros((batch_size, ), dtype="float32")
+ for i in range(batch_size):
+ iou[i] = iou_single(a[i], b[i], mask[i], n_class)
+
+ if reduce:
+ iou = paddle.mean(iou)
+ return iou
+
+
+class DiceLoss(nn.Layer):
+ def __init__(self, loss_weight=1.0):
+ super(DiceLoss, self).__init__()
+ self.loss_weight = loss_weight
+
+ def forward(self, input, target, mask, reduce=True):
+ batch_size = input.shape[0]
+ input = F.sigmoid(input) # scale to 0-1
+
+ input = input.reshape((batch_size, -1))
+ target = paddle.cast(target.reshape((batch_size, -1)), "float32")
+ mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
+
+ input = input * mask
+ target = target * mask
+
+ a = paddle.sum(input * target, axis=1)
+ b = paddle.sum(input * input, axis=1) + 0.001
+ c = paddle.sum(target * target, axis=1) + 0.001
+ d = (2 * a) / (b + c)
+ loss = 1 - d
+
+ loss = self.loss_weight * loss
+
+ if reduce:
+ loss = paddle.mean(loss)
+
+ return loss
+
+
+class SmoothL1Loss(nn.Layer):
+ def __init__(self, beta=1.0, loss_weight=1.0):
+ super(SmoothL1Loss, self).__init__()
+ self.beta = beta
+ self.loss_weight = loss_weight
+
+ np_coord = np.zeros(shape=[640, 640, 2], dtype=np.int64)
+ for i in range(640):
+ for j in range(640):
+ np_coord[i, j, 0] = j
+ np_coord[i, j, 1] = i
+ np_coord = np_coord.reshape((-1, 2))
+
+ self.coord = self.create_parameter(
+ shape=[640 * 640, 2],
+ dtype="int32", # NOTE: not support "int64" before paddle 2.3.1
+ default_initializer=nn.initializer.Assign(value=np_coord))
+ self.coord.stop_gradient = True
+
+ def forward_single(self, input, target, mask, beta=1.0, eps=1e-6):
+ batch_size = input.shape[0]
+
+ diff = paddle.abs(input - target) * mask.unsqueeze(1)
+ loss = paddle.where(diff < beta, 0.5 * diff * diff / beta,
+ diff - 0.5 * beta)
+ loss = paddle.cast(loss.reshape((batch_size, -1)), "float32")
+ mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
+ loss = paddle.sum(loss, axis=-1)
+ loss = loss / (mask.sum(axis=-1) + eps)
+
+ return loss
+
+ def select_single(self, distance, gt_instance, gt_kernel_instance,
+ training_mask):
+
+ with paddle.no_grad():
+ # paddle 2.3.1, paddle.slice not support:
+ # distance[:, self.coord[:, 1], self.coord[:, 0]]
+ select_distance_list = []
+ for i in range(2):
+ tmp1 = distance[i, :]
+ tmp2 = tmp1[self.coord[:, 1], self.coord[:, 0]]
+ select_distance_list.append(tmp2.unsqueeze(0))
+ select_distance = paddle.concat(select_distance_list, axis=0)
+
+ off_points = paddle.cast(
+ self.coord, "float32") + 10 * select_distance.transpose((1, 0))
+
+ off_points = paddle.cast(off_points, "int64")
+ off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1)
+
+ selected_mask = (
+ gt_instance[self.coord[:, 1], self.coord[:, 0]] !=
+ gt_kernel_instance[off_points[:, 1], off_points[:, 0]])
+ selected_mask = paddle.cast(
+ selected_mask.reshape((1, -1, distance.shape[-1])), "int64")
+ selected_training_mask = selected_mask * training_mask
+
+ return selected_training_mask
+
+ def forward(self,
+ distances,
+ gt_instances,
+ gt_kernel_instances,
+ training_masks,
+ gt_distances,
+ reduce=True):
+
+ selected_training_masks = []
+ for i in range(distances.shape[0]):
+ selected_training_masks.append(
+ self.select_single(distances[i, :, :, :], gt_instances[i, :, :],
+ gt_kernel_instances[i, :, :], training_masks[
+ i, :, :]))
+ selected_training_masks = paddle.cast(
+ paddle.concat(selected_training_masks, 0), "float32")
+
+ loss = self.forward_single(distances, gt_distances,
+ selected_training_masks, self.beta)
+ loss = self.loss_weight * loss
+
+ with paddle.no_grad():
+ batch_size = distances.shape[0]
+ false_num = selected_training_masks.reshape((batch_size, -1))
+ false_num = false_num.sum(axis=-1)
+ total_num = paddle.cast(
+ training_masks.reshape((batch_size, -1)), "float32")
+ total_num = total_num.sum(axis=-1)
+ iou_text = (total_num - false_num) / (total_num + 1e-6)
+
+ if reduce:
+ loss = paddle.mean(loss)
+
+ return loss, iou_text
+
+
+class CTLoss(nn.Layer):
+ def __init__(self):
+ super(CTLoss, self).__init__()
+ self.kernel_loss = DiceLoss()
+ self.loc_loss = SmoothL1Loss(beta=0.1, loss_weight=0.05)
+
+ def forward(self, preds, batch):
+ imgs = batch[0]
+ out = preds['maps']
+ gt_kernels, training_masks, gt_instances, gt_kernel_instances, training_mask_distances, gt_distances = batch[
+ 1:]
+
+ kernels = out[:, 0, :, :]
+ distances = out[:, 1:, :, :]
+
+ # kernel loss
+ selected_masks = ohem_batch(kernels, gt_kernels, training_masks)
+
+ loss_kernel = self.kernel_loss(
+ kernels, gt_kernels, selected_masks, reduce=False)
+
+ iou_kernel = iou(paddle.cast((kernels > 0), "int64"),
+ gt_kernels,
+ training_masks,
+ reduce=False)
+ losses = dict(loss_kernels=loss_kernel, )
+
+ # loc loss
+ loss_loc, iou_text = self.loc_loss(
+ distances,
+ gt_instances,
+ gt_kernel_instances,
+ training_mask_distances,
+ gt_distances,
+ reduce=False)
+ losses.update(dict(loss_loc=loss_loc, ))
+
+ loss_all = loss_kernel + loss_loc
+ losses = {'loss': loss_all}
+
+ return losses
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 853647c06cf0519a0e049e14c16a0d3e26f9845b..a39d0a464f3f96b44d23cec55768223ca41311fa 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -31,12 +31,14 @@ from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
from .sr_metric import SRMetric
+from .ct_metric import CTMetric
+
def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
- 'VQAReTokenMetric', 'SRMetric'
+ 'VQAReTokenMetric', 'SRMetric', 'CTMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/ct_metric.py b/ppocr/metrics/ct_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7634230a23027a5dd5c32a7b8eb87ee4a229076
--- /dev/null
+++ b/ppocr/metrics/ct_metric.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from scipy import io
+import numpy as np
+
+from ppocr.utils.e2e_metric.Deteval import combine_results, get_score_C
+
+
+class CTMetric(object):
+ def __init__(self, main_indicator, delimiter='\t', **kwargs):
+ self.delimiter = delimiter
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def reset(self):
+ self.results = [] # clear results
+
+ def __call__(self, preds, batch, **kwargs):
+ # NOTE: only support bs=1 now, as the label length of different sample is Unequal
+ assert len(
+ preds) == 1, "CentripetalText test now only suuport batch_size=1."
+ label = batch[2]
+ text = batch[3]
+ pred = preds[0]['points']
+ result = get_score_C(label, text, pred)
+
+ self.results.append(result)
+
+ def get_metric(self):
+ """
+ Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
+ """
+ metrics = combine_results(self.results, rec_flag=False)
+ self.reset()
+ return metrics
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 0feda6c6e062fa314d97b8949d8545ed3305c22e..751757e5f176119688e2db47a68c514850b91823 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -23,6 +23,7 @@ def build_head(config):
from .det_pse_head import PSEHead
from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead
+ from .det_ct_head import CT_Head
# rec head
from .rec_ctc_head import CTCHead
@@ -52,7 +53,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
- 'VLHead', 'SLAHead', 'RobustScannerHead'
+ 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
]
#table head
diff --git a/ppocr/modeling/heads/det_ct_head.py b/ppocr/modeling/heads/det_ct_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..08e6719e8f0ade6887eb4ad7f44a2bc36ec132db
--- /dev/null
+++ b/ppocr/modeling/heads/det_ct_head.py
@@ -0,0 +1,69 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+import math
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.)
+
+
+class CT_Head(nn.Layer):
+ def __init__(self,
+ in_channels,
+ hidden_dim,
+ num_classes,
+ loss_kernel=None,
+ loss_loc=None):
+ super(CT_Head, self).__init__()
+ self.conv1 = nn.Conv2D(
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
+ self.bn1 = nn.BatchNorm2D(hidden_dim)
+ self.relu1 = nn.ReLU()
+
+ self.conv2 = nn.Conv2D(
+ hidden_dim, num_classes, kernel_size=1, stride=1, padding=0)
+
+ for m in self.sublayers():
+ if isinstance(m, nn.Conv2D):
+ n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
+ normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2D):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def _upsample(self, x, scale=1):
+ return F.upsample(x, scale_factor=scale, mode='bilinear')
+
+ def forward(self, f, targets=None):
+ out = self.conv1(f)
+ out = self.relu1(self.bn1(out))
+ out = self.conv2(out)
+
+ if self.training:
+ out = self._upsample(out, scale=4)
+ return {'maps': out}
+ else:
+ score = F.sigmoid(out[:, 0, :, :])
+ return {'maps': out, 'score': score}
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index e3ae2d6ef27821f592645a4ba945d3feeaa8cf8a..c7e8dd068b4a68e56b066ca8fa629644a8f302c6 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -26,13 +26,15 @@ def build_neck(config):
from .fce_fpn import FCEFPN
from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN
+ from .ct_fpn import CTFPN
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
- 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
+ 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
]
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
support_dict))
+
module_class = eval(module_name)(**config)
return module_class
diff --git a/ppocr/modeling/necks/ct_fpn.py b/ppocr/modeling/necks/ct_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee4d25e901b5b3093588571f0412a931eaf6f364
--- /dev/null
+++ b/ppocr/modeling/necks/ct_fpn.py
@@ -0,0 +1,185 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+import os
+import sys
+
+import math
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.)
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
+
+
+class Conv_BN_ReLU(nn.Layer):
+ def __init__(self,
+ in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=1,
+ padding=0):
+ super(Conv_BN_ReLU, self).__init__()
+ self.conv = nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias_attr=False)
+ self.bn = nn.BatchNorm2D(out_planes)
+ self.relu = nn.ReLU()
+
+ for m in self.sublayers():
+ if isinstance(m, nn.Conv2D):
+ n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
+ normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2D):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward(self, x):
+ return self.relu(self.bn(self.conv(x)))
+
+
+class FPEM(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super(FPEM, self).__init__()
+ planes = out_channels
+ self.dwconv3_1 = nn.Conv2D(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=planes,
+ bias_attr=False)
+ self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes)
+
+ self.dwconv2_1 = nn.Conv2D(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=planes,
+ bias_attr=False)
+ self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes)
+
+ self.dwconv1_1 = nn.Conv2D(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=planes,
+ bias_attr=False)
+ self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes)
+
+ self.dwconv2_2 = nn.Conv2D(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=planes,
+ bias_attr=False)
+ self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes)
+
+ self.dwconv3_2 = nn.Conv2D(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=planes,
+ bias_attr=False)
+ self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes)
+
+ self.dwconv4_2 = nn.Conv2D(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=planes,
+ bias_attr=False)
+ self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes)
+
+ def _upsample_add(self, x, y):
+ return F.upsample(x, scale_factor=2, mode='bilinear') + y
+
+ def forward(self, f1, f2, f3, f4):
+ # up-down
+ f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3)))
+ f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2)))
+ f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1)))
+
+ # down-up
+ f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1)))
+ f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2)))
+ f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3)))
+
+ return f1, f2, f3, f4
+
+
+class CTFPN(nn.Layer):
+ def __init__(self, in_channels, out_channel=128):
+ super(CTFPN, self).__init__()
+ self.out_channels = out_channel * 4
+
+ self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128)
+ self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128)
+ self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128)
+ self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128)
+
+ self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
+ self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
+
+ def _upsample(self, x, scale=1):
+ return F.upsample(x, scale_factor=scale, mode='bilinear')
+
+ def forward(self, f):
+ # # reduce channel
+ f1 = self.reduce_layer1(f[0]) # N,64,160,160 --> N, 128, 160, 160
+ f2 = self.reduce_layer2(f[1]) # N, 128, 80, 80 --> N, 128, 80, 80
+ f3 = self.reduce_layer3(f[2]) # N, 256, 40, 40 --> N, 128, 40, 40
+ f4 = self.reduce_layer4(f[3]) # N, 512, 20, 20 --> N, 128, 20, 20
+
+ # FPEM
+ f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4)
+ f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1)
+
+ # FFM
+ f1 = f1_1 + f1_2
+ f2 = f2_1 + f2_2
+ f3 = f3_1 + f3_2
+ f4 = f4_1 + f4_2
+
+ f2 = self._upsample(f2, scale=2)
+ f3 = self._upsample(f3, scale=4)
+ f4 = self._upsample(f4, scale=8)
+ ff = paddle.concat((f1, f2, f3, f4), 1) # N,512, 160,160
+ return ff
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 8f41a005f5b90e7edf11fad80b9b7eac89257160..35b7a6800da422264a796da14236ae8a484c30d9 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -35,6 +35,7 @@ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess,
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
+from .ct_postprocess import CTPostProcess
def build_post_process(config, global_config=None):
@@ -48,7 +49,7 @@ def build_post_process(config, global_config=None):
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
- 'VLLabelDecode', 'PicoDetPostProcess'
+ 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/ct_postprocess.py b/ppocr/postprocess/ct_postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..3ab90be24d65888339698a5abe2ed692ceaab4c7
--- /dev/null
+++ b/ppocr/postprocess/ct_postprocess.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refered from:
+https://github.com/shengtao96/CentripetalText/blob/main/test.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import os.path as osp
+import numpy as np
+import cv2
+import paddle
+import pyclipper
+
+
+class CTPostProcess(object):
+ """
+ The post process for Centripetal Text (CT).
+ """
+
+ def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs):
+ self.min_score = min_score
+ self.min_area = min_area
+ self.box_type = box_type
+
+ self.coord = np.zeros((2, 300, 300), dtype=np.int32)
+ for i in range(300):
+ for j in range(300):
+ self.coord[0, i, j] = j
+ self.coord[1, i, j] = i
+
+ def __call__(self, preds, batch):
+ outs = preds['maps']
+ out_scores = preds['score']
+
+ if isinstance(outs, paddle.Tensor):
+ outs = outs.numpy()
+ if isinstance(out_scores, paddle.Tensor):
+ out_scores = out_scores.numpy()
+
+ batch_size = outs.shape[0]
+ boxes_batch = []
+ for idx in range(batch_size):
+ bboxes = []
+ scores = []
+
+ img_shape = batch[idx]
+
+ org_img_size = img_shape[:3]
+ img_shape = img_shape[3:]
+ img_size = img_shape[:2]
+
+ out = np.expand_dims(outs[idx], axis=0)
+ outputs = dict()
+
+ score = np.expand_dims(out_scores[idx], axis=0)
+
+ kernel = out[:, 0, :, :] > 0.2
+ loc = out[:, 1:, :, :].astype("float32")
+
+ score = score[0].astype(np.float32)
+ kernel = kernel[0].astype(np.uint8)
+ loc = loc[0].astype(np.float32)
+
+ label_num, label_kernel = cv2.connectedComponents(
+ kernel, connectivity=4)
+
+ for i in range(1, label_num):
+ ind = (label_kernel == i)
+ if ind.sum(
+ ) < 10: # pixel number less than 10, treated as background
+ label_kernel[ind] = 0
+
+ label = np.zeros_like(label_kernel)
+ h, w = label_kernel.shape
+ pixels = self.coord[:, :h, :w].reshape(2, -1)
+ points = pixels.transpose([1, 0]).astype(np.float32)
+
+ off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T
+ ).astype(np.int32)
+ off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1)
+ off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1)
+
+ label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1],
+ off_points[:, 0]]
+ label[label_kernel > 0] = label_kernel[label_kernel > 0]
+
+ score_pocket = [0.0]
+ for i in range(1, label_num):
+ ind = (label_kernel == i)
+ if ind.sum() == 0:
+ score_pocket.append(0.0)
+ continue
+ score_i = np.mean(score[ind])
+ score_pocket.append(score_i)
+
+ label_num = np.max(label) + 1
+ label = cv2.resize(
+ label, (img_size[1], img_size[0]),
+ interpolation=cv2.INTER_NEAREST)
+
+ scale = (float(org_img_size[1]) / float(img_size[1]),
+ float(org_img_size[0]) / float(img_size[0]))
+
+ for i in range(1, label_num):
+ ind = (label == i)
+ points = np.array(np.where(ind)).transpose((1, 0))
+
+ if points.shape[0] < self.min_area:
+ continue
+
+ score_i = score_pocket[i]
+ if score_i < self.min_score:
+ continue
+
+ if self.box_type == 'rect':
+ rect = cv2.minAreaRect(points[:, ::-1])
+ bbox = cv2.boxPoints(rect) * scale
+ z = bbox.mean(0)
+ bbox = z + (bbox - z) * 0.85
+ elif self.box_type == 'poly':
+ binary = np.zeros(label.shape, dtype='uint8')
+ binary[ind] = 1
+ try:
+ _, contours, _ = cv2.findContours(
+ binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ except BaseException:
+ contours, _ = cv2.findContours(
+ binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ bbox = contours[0] * scale
+
+ bbox = bbox.astype('int32')
+ bboxes.append(bbox.reshape(-1, 2))
+ scores.append(score_i)
+
+ boxes_batch.append({'points': bboxes})
+
+ return boxes_batch
diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py
index 45567a7dd2d82b6c583abd4a4eabef52974be081..6ce56eda2aa9f38fdc712d49ae64945c558b418d 100755
--- a/ppocr/utils/e2e_metric/Deteval.py
+++ b/ppocr/utils/e2e_metric/Deteval.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import numpy as np
import scipy.io as io
+import Polygon as plg
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
@@ -269,7 +271,124 @@ def get_socre_B(gt_dir, img_id, pred_dict):
return single_data
-def combine_results(all_data):
+def get_score_C(gt_label, text, pred_bboxes):
+ """
+ get score for CentripetalText (CT) prediction.
+ """
+
+ def gt_reading_mod(gt_label, text):
+ """This helper reads groundtruths from mat files"""
+ groundtruths = []
+ nbox = len(gt_label)
+ for i in range(nbox):
+ label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
+ groundtruths.append(label)
+
+ return groundtruths
+
+ def get_union(pD, pG):
+ areaA = pD.area()
+ areaB = pG.area()
+ return areaA + areaB - get_intersection(pD, pG)
+
+ def get_intersection(pD, pG):
+ pInt = pD & pG
+ if len(pInt) == 0:
+ return 0
+ return pInt.area()
+
+ def detection_filtering(detections, groundtruths, threshold=0.5):
+ for gt in groundtruths:
+ point_num = gt['points'].shape[1] // 2
+ if gt['transcription'] == '###' and (point_num > 1):
+ gt_p = np.array(gt['points']).reshape(point_num,
+ 2).astype('int32')
+ gt_p = plg.Polygon(gt_p)
+
+ for det_id, detection in enumerate(detections):
+ det_y = detection[0::2]
+ det_x = detection[1::2]
+
+ det_p = np.concatenate((np.array(det_x), np.array(det_y)))
+ det_p = det_p.reshape(2, -1).transpose()
+ det_p = plg.Polygon(det_p)
+
+ try:
+ det_gt_iou = get_intersection(det_p,
+ gt_p) / det_p.area()
+ except:
+ print(det_x, det_y, gt_p)
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_p, gt_p):
+ """
+ sigma = inter_area / gt_area
+ """
+ if gt_p.area() == 0.:
+ return 0
+ return get_intersection(det_p, gt_p) / gt_p.area()
+
+ def tau_calculation(det_p, gt_p):
+ """
+ tau = inter_area / det_area
+ """
+ if det_p.area() == 0.:
+ return 0
+ return get_intersection(det_p, gt_p) / det_p.area()
+
+ detections = []
+
+ for item in pred_bboxes:
+ detections.append(item[:, ::-1].reshape(-1))
+
+ groundtruths = gt_reading_mod(gt_label, text)
+
+ detections = detection_filtering(
+ detections, groundtruths) # filters detections overlapping with DC area
+
+ for idx in range(len(groundtruths) - 1, -1, -1):
+ #NOTE: source code use 'orin' to indicate '#', here we use 'anno',
+ # which may cause slight drop in fscore, about 0.12
+ if groundtruths[idx]['transcription'] == '###':
+ groundtruths.pop(idx)
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ point_num = gt['points'].shape[1] // 2
+
+ gt_p = np.array(gt['points']).reshape(point_num,
+ 2).astype('int32')
+ gt_p = plg.Polygon(gt_p)
+
+ det_y = detection[0::2]
+ det_x = detection[1::2]
+
+ det_p = np.concatenate((np.array(det_x), np.array(det_y)))
+
+ det_p = det_p.reshape(2, -1).transpose()
+ det_p = plg.Polygon(det_p)
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
+ gt_p)
+ local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
+
+ data = {}
+ data['sigma'] = local_sigma_table
+ data['global_tau'] = local_tau_table
+ data['global_pred_str'] = ''
+ data['global_gt_str'] = ''
+ return data
+
+
+def combine_results(all_data, rec_flag=True):
tr = 0.7
tp = 0.6
fsc_k = 0.8
@@ -278,6 +397,7 @@ def combine_results(all_data):
global_tau = []
global_pred_str = []
global_gt_str = []
+
for data in all_data:
global_sigma.append(data['sigma'])
global_tau.append(data['global_tau'])
@@ -294,7 +414,7 @@ def combine_results(all_data):
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idy):
+ gt_flag, det_flag, idy, rec_flag):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
@@ -328,14 +448,15 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
- gt_str_cur = global_gt_str[idy][gt_id]
- pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
- 0]]
- if pred_str_cur == gt_str_cur:
- hit_str_num += 1
- else:
- if pred_str_cur.lower() == gt_str_cur.lower():
+ if rec_flag:
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][matched_det_id[0]
+ .tolist()[0]]
+ if pred_str_cur == gt_str_cur:
hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
@@ -343,7 +464,7 @@ def combine_results(all_data):
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idy):
+ gt_flag, det_flag, idy, rec_flag):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
@@ -374,28 +495,30 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
- gt_str_cur = global_gt_str[idy][gt_id]
- pred_str_cur = global_pred_str[idy][
- qualified_tau_candidates[0].tolist()[0]]
- if pred_str_cur == gt_str_cur:
- hit_str_num += 1
- else:
- if pred_str_cur.lower() == gt_str_cur.lower():
+ if rec_flag:
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
- gt_str_cur = global_gt_str[idy][gt_id]
- pred_str_cur = global_pred_str[idy][
- qualified_tau_candidates[0].tolist()[0]]
- if pred_str_cur == gt_str_cur:
- hit_str_num += 1
- else:
- if pred_str_cur.lower() == gt_str_cur.lower():
+ if rec_flag:
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
@@ -409,7 +532,7 @@ def combine_results(all_data):
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idy):
+ gt_flag, det_flag, idy, rec_flag):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
@@ -440,6 +563,30 @@ def combine_results(all_data):
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
+ if rec_flag:
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[
+ 0].tolist()[idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower(
+ ):
+ hit_str_num += 1
+ break
+ # recg end
+ elif (np.sum(local_tau_table[qualified_sigma_candidates,
+ det_id]) >= tp):
+ det_flag[0, det_id] = 1
+ gt_flag[0, qualified_sigma_candidates] = 1
+ # recg start
+ if rec_flag:
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
@@ -454,27 +601,7 @@ def combine_results(all_data):
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
- break
- # recg end
- elif (np.sum(local_tau_table[qualified_sigma_candidates,
- det_id]) >= tp):
- det_flag[0, det_id] = 1
- gt_flag[0, qualified_sigma_candidates] = 1
- # recg start
- pred_str_cur = global_pred_str[idy][det_id]
- gt_len = len(qualified_sigma_candidates[0])
- for idx in range(gt_len):
- ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
- if ele_gt_id not in global_gt_str[idy]:
- continue
- gt_str_cur = global_gt_str[idy][ele_gt_id]
- if pred_str_cur == gt_str_cur:
- hit_str_num += 1
- break
- else:
- if pred_str_cur.lower() == gt_str_cur.lower():
- hit_str_num += 1
- break
+ break
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
@@ -504,7 +631,7 @@ def combine_results(all_data):
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idx)
+ gt_flag, det_flag, idx, rec_flag)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
@@ -512,14 +639,14 @@ def combine_results(all_data):
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idx)
+ gt_flag, det_flag, idx, rec_flag)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
- gt_flag, det_flag, idx)
+ gt_flag, det_flag, idx, rec_flag)
hit_str_count += hit_str_num
try:
diff --git a/requirements.txt b/requirements.txt
index 2c0741a065dacf1fb637865e8f9796a611876d60..43cd8c1b082768ebad44a5cf58fc31980ebfe891 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ lxml
premailer
openpyxl
attrdict
+Polygon3
diff --git a/test_tipc/configs/det_r18_ct/train_infer_python.txt b/test_tipc/configs/det_r18_ct/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5933fdbeed762a73324fbfb5a4113a390926e7ea
--- /dev/null
+++ b/test_tipc/configs/det_r18_ct/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:det_r18_ct
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/total_text/test/rgb/
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/det_r18_vd_ct.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
+quant_export:null
+fpgm_export:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c configs/det/det_r18_vd_ct.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/det_r18_vd_ct/best_accuracy
+infer_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+--save_log_path:null
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 1185dec59e334802cc17b905b98d512e2498497b..a4ba31928bba4a00a560461392f7011244af5e0c 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -274,6 +274,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./train_data/ && tar xf XFUND.tar
cd ../
fi
+ if [ ${model_name} == "det_r18_ct" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ct_tipc/total_text_lite2.tar --no-check-certificate
+ cd ./train_data && tar xf total_text_lite2.tar && ln -s total_text_lite2 total_text && cd ../
+ fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 9f5c480d3c55367a02eacb48bed6ae3d38282f05..00fa2e9b7fafd949c59a0eebd43f2f88ae717320 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -127,6 +127,9 @@ class TextDetector(object):
postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type
+ elif self.det_algorithm == "CT":
+ pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
+ postprocess_params['name'] = 'CTPostProcess'
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
@@ -253,6 +256,9 @@ class TextDetector(object):
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
preds['level_{}'.format(i)] = output
+ elif self.det_algorithm == "CT":
+ preds['maps'] = outputs[0]
+ preds['score'] = outputs[1]
else:
raise NotImplementedError
@@ -260,7 +266,7 @@ class TextDetector(object):
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
- self.det_algorithm in ["PSE", "FCE"] and
+ self.det_algorithm in ["PSE", "FCE", "CT"] and
self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
diff --git a/tools/program.py b/tools/program.py
index c91e66fd7f1d42482ee3b5c2c39911cc24d966e1..9117d51b95b343c46982f212d4e5faa069b7b44a 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -625,7 +625,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
- 'Gestalt', 'SLANet', 'RobustScanner'
+ 'Gestalt', 'SLANet', 'RobustScanner', 'CT'
]
if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index d0f200189e34265b3c080ac9e25eb80d29c705b7..970a52624af7b2831d88956f857cd4271086bcca 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
+
use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -138,7 +139,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
-
+
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
@@ -146,7 +147,7 @@ def main(config, device, logger, vdl_writer):
use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2')
- amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
+ amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
@@ -161,20 +162,24 @@ def main(config, device, logger, vdl_writer):
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if amp_level == "O2":
model, optimizer = paddle.amp.decorate(
- models=model, optimizers=optimizer, level=amp_level, master_weight=True)
+ models=model,
+ optimizers=optimizer,
+ level=amp_level,
+ master_weight=True)
else:
scaler = None
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
-
+
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list)
+ eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
+ amp_level, amp_custom_black_list)
def test_reader(config, device, logger):
diff --git a/train.sh b/train.sh
index 4225470cb9f545b874e5f806af22405895e8f6c7..6fa04ea3febe8982016a35d83f119c0a483e3bb8 100644
--- a/train.sh
+++ b/train.sh
@@ -1,2 +1,2 @@
# recommended paddle.__version__ == 2.0.0
-python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
+python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
\ No newline at end of file