diff --git a/configs/det/det_r18_vd_ct.yml b/configs/det/det_r18_vd_ct.yml
new file mode 100644
index 0000000000000000000000000000000000000000..42922dfd22c0e49d20d50534c76fedae16b27a4a
--- /dev/null
+++ b/configs/det/det_r18_vd_ct.yml
@@ -0,0 +1,107 @@
+Global:
+  use_gpu: true
+  epoch_num: 600
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/det_ct/
+  save_epoch_step: 10
+  # evaluation is run every 1000 iterations
+  eval_batch_step: [0,1000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img623.jpg
+  save_res_path: ./output/det_ct/predicts_ct.txt
+
+Architecture:
+  model_type: det
+  algorithm: CT
+  Transform:
+  Backbone:
+    name: ResNet_vd
+    layers: 18
+  Neck:
+    name: CTFPN
+  Head:
+    name: CT_Head
+    in_channels: 512
+    hidden_dim: 128
+    num_classes: 3
+
+Loss:
+  name: CTLoss
+
+Optimizer:
+  name: Adam
+  lr: #PolynomialDecay
+    name: Linear
+    learning_rate: 0.001
+    end_lr: 0.
+    epochs: 600
+    step_each_epoch: 1254
+    power: 0.9
+
+PostProcess:
+  name: CTPostProcess
+  box_type: poly
+
+Metric:
+  name: CTMetric
+  main_indicator: f_score
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/total_text/train
+    label_file_list:
+      - ./train_data/total_text/train/train.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage:
+          img_mode: RGB
+          channel_first: False
+      - CTLabelEncode: # Class handling label
+      - RandomScale:
+      - MakeShrink:
+      - GroupRandomHorizontalFlip:
+      - GroupRandomRotate:
+      - GroupRandomCropPadding:
+      - MakeCentripetalShift:
+      - ColorJitter:
+          brightness: 0.125
+          saturation: 0.5
+      - ToCHWImage:
+      - NormalizeImage:
+      - KeepKeys:
+          keep_keys: ['image', 'gt_kernel', 'training_mask', 'gt_instance', 'gt_kernel_instance', 'training_mask_distance', 'gt_distance'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: True
+    batch_size_per_card: 4
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/total_text/test
+    label_file_list:
+      - ./train_data/total_text/test/test.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage:
+          img_mode: RGB
+          channel_first: False
+      - CTLabelEncode: # Class handling label
+      - ScaleAlignedShort:
+      - NormalizeImage:
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'texts'] # the order of the dataloader list
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1
+    num_workers: 2
diff --git a/doc/doc_ch/algorithm_det_ct.md b/doc/doc_ch/algorithm_det_ct.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea3522b7bf3c2dc17ef4f645bc47738477f07cf1
--- /dev/null
+++ b/doc/doc_ch/algorithm_det_ct.md
@@ -0,0 +1,95 @@
+# CT
+
+- [1. Algorithm Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving Deployment](#4-3)
+  - [4.4 More Deployment Options](#4-4)
+- [5. FAQ](#5)
+
+<a name="1"></a>
+## 1. Algorithm Introduction
+
+Paper:
+> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945)
+> Tao Sheng, Jie Chen, Zhouhui Lian
+> NeurIPS, 2021
+
+On the Total-Text public text detection dataset, the reproduced results are as follows:
+
+|Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
+| --- | --- | --- | --- | --- | --- | --- |
+|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)|
+
+<a name="2"></a>
+## 2. Environment
+Please refer to [Environment Preparation](./environment.md) to set up the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.
+
+<a name="3"></a>
+## 3. Model Training / Evaluation / Prediction
+
+The CT model is trained on the Total-Text public text detection dataset, which can be downloaded from [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset). We converted the label files into PaddleOCR format; the converted labels can be downloaded from [train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt) and [test.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt).
+
+Please refer to the [text detection training tutorial](./detection.md). PaddleOCR modularizes the code, so training a different detection model only requires **changing the configuration file**.
+
+<a name="4"></a>
+## 4. Inference and Deployment
+
+<a name="4-1"></a>
+### 4.1 Python Inference
+First, convert the model saved during CT text detection training into an inference model. Taking the model based on the ResNet18_vd backbone and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)), the conversion command is as follows:
+
+```shell
+python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct
+```
+
+For CT text detection model inference, run the following command:
+
+```shell
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT"
+```
+
+The visualized text detection results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'det_res'. An example result:
+
+![](../imgs_results/det_res_img623_ct.jpg)
+
+<a name="4-2"></a>
+### 4.2 C++ Inference
+
+Not supported
+
+<a name="4-3"></a>
+### 4.3 Serving Deployment
+
+Not supported
+
+<a name="4-4"></a>
+### 4.4 More Deployment Options
+
+Not supported
+
+<a name="5"></a>
+## 5. FAQ
+
+## Citation
+
+```bibtex
+@inproceedings{sheng2021centripetaltext,
+    title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection},
+    author={Tao Sheng and Jie Chen and Zhouhui Lian},
+    booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
+    year={2021}
+}
+```
diff --git a/doc/doc_en/algorithm_det_ct_en.md b/doc/doc_en/algorithm_det_ct_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..d56b3fc6b3353bacb1f26fba3873ba5276b10c8b
--- /dev/null
+++ b/doc/doc_en/algorithm_det_ct_en.md
@@ -0,0 +1,96 @@
+# CT
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+<a name="1"></a>
+## 1. Introduction
+
+Paper:
+> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945)
+> Tao Sheng, Jie Chen, Zhouhui Lian
+> NeurIPS, 2021
+
+On the Total-Text dataset, the text detection result is as follows:
+
+|Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
+| --- | --- | --- | --- | --- | --- | --- |
+|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)|
+
+<a name="2"></a>
+## 2. Environment
+Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md).
+
+<a name="3"></a>
+## 3. Model Training / Evaluation / Prediction
+
+
+The CT model above is trained on the Total-Text public text detection dataset. The dataset can be downloaded from [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset); the PaddleOCR-format annotations can be downloaded from [train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt) and [test.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt).
+
+Please refer to the [text detection training tutorial](./detection_en.md). PaddleOCR has modularized the code structure, so you only need to **replace the configuration file** to train a different detection model.
+
+<a name="4"></a>
+## 4. Inference and Deployment
+
+<a name="4-1"></a>
+### 4.1 Python Inference
+First, convert the model saved during CT text detection training into an inference model. Taking the model based on the ResNet18_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)), you can use the following command to convert it:
+
+```shell
+python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct
+```
+
+For CT text detection model inference, you can execute the following command:
+
+```shell
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT"
+```
+
+The visualized text detection results are saved to the `./inference_results` folder by default, and the names of the result files are prefixed with 'det_res'. Examples of results are as follows:
+
+![](../imgs_results/det_res_img623_ct.jpg)
+
+<a name="4-2"></a>
+### 4.2 C++ Inference
+
+Not supported
+
+<a name="4-3"></a>
+### 4.3 Serving
+
+Not supported
+
+<a name="4-4"></a>
+### 4.4 More
+
+Not supported
+
+<a name="5"></a>
+## 5. FAQ
+
+## Citation
+
+```bibtex
+@inproceedings{sheng2021centripetaltext,
+    title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection},
+    author={Tao Sheng and Jie Chen and Zhouhui Lian},
+    booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
+    year={2021}
+}
+```
diff --git a/doc/imgs_results/det_res_img623_ct.jpg b/doc/imgs_results/det_res_img623_ct.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2c5f57d96cca896c70d9e0d33ba80a0177a8aeb9
Binary files /dev/null and b/doc/imgs_results/det_res_img623_ct.jpg differ
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 102f48fcc19e59d9f8ffb0ad496f54cc64864f7d..863988cccfa9d9f2c865a444410d4245687f49ee 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -43,6 +43,7 @@ from .vqa import *
 from .fce_aug import *
 from .fce_targets import FCENetTargets
+from .ct_process import *
 
 
 def transform(data, ops=None):
diff --git a/ppocr/data/imaug/ct_process.py b/ppocr/data/imaug/ct_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..59715090036e1020800950b02b9ea06ab5c8d4c2
--- /dev/null
+++ b/ppocr/data/imaug/ct_process.py
@@ -0,0 +1,355 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import random +import pyclipper +import paddle + +import numpy as np +import Polygon as plg +import scipy.io as scio + +from PIL import Image +import paddle.vision.transforms as transforms + + +class RandomScale(): + def __init__(self, short_size=640, **kwargs): + self.short_size = short_size + + def scale_aligned(self, img, scale): + oh, ow = img.shape[0:2] + h = int(oh * scale + 0.5) + w = int(ow * scale + 0.5) + if h % 32 != 0: + h = h + (32 - h % 32) + if w % 32 != 0: + w = w + (32 - w % 32) + img = cv2.resize(img, dsize=(w, h)) + factor_h = h / oh + factor_w = w / ow + return img, factor_h, factor_w + + def __call__(self, data): + img = data['image'] + + h, w = img.shape[0:2] + random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]) + scale = (np.random.choice(random_scale) * self.short_size) / min(h, w) + img, factor_h, factor_w = self.scale_aligned(img, scale) + + data['scale_factor'] = (factor_w, factor_h) + data['image'] = img + return data + + +class MakeShrink(): + def __init__(self, kernel_scale=0.7, **kwargs): + self.kernel_scale = kernel_scale + + def dist(self, a, b): + return np.linalg.norm((a - b), ord=2, axis=0) + + def perimeter(self, bbox): + peri = 0.0 + for i in range(bbox.shape[0]): + peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) + return peri + + def shrink(self, bboxes, rate, max_shr=20): + rate = rate * rate + shrinked_bboxes = [] + for bbox in bboxes: + area = plg.Polygon(bbox).area() + peri = self.perimeter(bbox) + + try: + pco = pyclipper.PyclipperOffset() + pco.AddPath(bbox, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + offset = min( + int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) + + shrinked_bbox = pco.Execute(-offset) + if len(shrinked_bbox) == 0: + shrinked_bboxes.append(bbox) + continue + + shrinked_bbox = np.array(shrinked_bbox[0]) + if shrinked_bbox.shape[0] <= 2: + shrinked_bboxes.append(bbox) + continue + + shrinked_bboxes.append(shrinked_bbox) + except Exception as e: + shrinked_bboxes.append(bbox) + + return shrinked_bboxes + + def __call__(self, data): + img = data['image'] + bboxes = data['polys'] + words = data['texts'] + scale_factor = data['scale_factor'] + + gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w + training_mask = np.ones(img.shape[0:2], dtype='uint8') + training_mask_distance = np.ones(img.shape[0:2], dtype='uint8') + + for i in range(len(bboxes)): + bboxes[i] = np.reshape(bboxes[i] * ( + [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)), + (bboxes[i].shape[0] // 2, 2)).astype('int32') + + for i in range(len(bboxes)): + #different value for different bbox + cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) + + # set training mask to 0 + cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) + + # for not accurate annotation, use training_mask_distance + if words[i] == '###' or words[i] == '???': + cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1) + + # make shrink + gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8') + kernel_bboxes = self.shrink(bboxes, self.kernel_scale) + for i in range(len(bboxes)): + 
cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1, + -1) + + # for training mask, kernel and background= 1, box region=0 + if words[i] != '###' and words[i] != '???': + cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1) + + gt_kernel = gt_kernel_instance.copy() + # for gt_kernel, kernel = 1 + gt_kernel[gt_kernel > 0] = 1 + + # shrink 2 times + tmp1 = gt_kernel_instance.copy() + erode_kernel = np.ones((3, 3), np.uint8) + tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1) + tmp2 = tmp1.copy() + tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1) + + # compute text region + gt_kernel_inner = tmp1 - tmp2 + + # gt_instance: text instance, bg=0, diff word use diff value + # training_mask: text instance mask, word=0,kernel and bg=1 + # gt_kernel_instance: text kernel instance, bg=0, diff word use diff value + # gt_kernel: text_kernel, bg=0,diff word use same value + # gt_kernel_inner: text kernel reference + # training_mask_distance: word without anno = 0, else 1 + + data['image'] = [ + img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, + gt_kernel_inner, training_mask_distance + ] + return data + + +class GroupRandomHorizontalFlip(): + def __init__(self, p=0.5, **kwargs): + self.p = p + + def __call__(self, data): + imgs = data['image'] + + if random.random() < self.p: + for i in range(len(imgs)): + imgs[i] = np.flip(imgs[i], axis=1).copy() + data['image'] = imgs + return data + + +class GroupRandomRotate(): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + imgs = data['image'] + + max_angle = 10 + angle = random.random() * 2 * max_angle - max_angle + for i in range(len(imgs)): + img = imgs[i] + w, h = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) + img_rotation = cv2.warpAffine( + img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) + imgs[i] = img_rotation + + data['image'] = imgs + return data + + +class GroupRandomCropPadding(): + def __init__(self, target_size=(640, 640), **kwargs): + self.target_size = target_size + + def __call__(self, data): + imgs = data['image'] + + h, w = imgs[0].shape[0:2] + t_w, t_h = self.target_size + p_w, p_h = self.target_size + if w == t_w and h == t_h: + return data + + t_h = t_h if t_h < h else h + t_w = t_w if t_w < w else w + + if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: + # make sure to crop the text region + tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) + tl[tl < 0] = 0 + br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) + br[br < 0] = 0 + br[0] = min(br[0], h - t_h) + br[1] = min(br[1], w - t_w) + + i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + i = random.randint(0, h - t_h) if h - t_h > 0 else 0 + j = random.randint(0, w - t_w) if w - t_w > 0 else 0 + + n_imgs = [] + for idx in range(len(imgs)): + if len(imgs[idx].shape) == 3: + s3_length = int(imgs[idx].shape[-1]) + img = imgs[idx][i:i + t_h, j:j + t_w, :] + img_p = cv2.copyMakeBorder( + img, + 0, + p_h - t_h, + 0, + p_w - t_w, + borderType=cv2.BORDER_CONSTANT, + value=tuple(0 for i in range(s3_length))) + else: + img = imgs[idx][i:i + t_h, j:j + t_w] + img_p = cv2.copyMakeBorder( + img, + 0, + p_h - t_h, + 0, + p_w - t_w, + borderType=cv2.BORDER_CONSTANT, + value=(0, )) + n_imgs.append(img_p) + + data['image'] = n_imgs + return data + + +class MakeCentripetalShift(): + def __init__(self, **kwargs): + pass + + def jaccard(self, As, Bs): + A = As.shape[0] # small + B = Bs.shape[0] # large + + dis 
= np.sqrt( + np.sum((As[:, np.newaxis, :].repeat( + B, axis=1) - Bs[np.newaxis, :, :].repeat( + A, axis=0))**2, + axis=-1)) + + ind = np.argmin(dis, axis=-1) + + return ind + + def __call__(self, data): + imgs = data['image'] + + img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \ + imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6] + + max_instance = np.max(gt_instance) # num bbox + + # make centripetal shift + gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32) + for i in range(1, max_instance + 1): + # kernel_reference + ind = (gt_kernel_inner == i) + + if np.sum(ind) == 0: + training_mask[gt_instance == i] = 0 + training_mask_distance[gt_instance == i] = 0 + continue + + kpoints = np.array(np.where(ind)).transpose( + (1, 0))[:, ::-1].astype('float32') + + ind = (gt_instance == i) * (gt_kernel_instance == 0) + if np.sum(ind) == 0: + continue + pixels = np.where(ind) + + points = np.array(pixels).transpose( + (1, 0))[:, ::-1].astype('float32') + + bbox_ind = self.jaccard(points, kpoints) + + offset_gt = kpoints[bbox_ind] - points + + gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1 + + img = Image.fromarray(img) + img = img.convert('RGB') + + data["image"] = img + data["gt_kernel"] = gt_kernel.astype("int64") + data["training_mask"] = training_mask.astype("int64") + data["gt_instance"] = gt_instance.astype("int64") + data["gt_kernel_instance"] = gt_kernel_instance.astype("int64") + data["training_mask_distance"] = training_mask_distance.astype("int64") + data["gt_distance"] = gt_distance.astype("float32") + + return data + + +class ScaleAlignedShort(): + def __init__(self, short_size=640, **kwargs): + self.short_size = short_size + + def __call__(self, data): + img = data['image'] + + org_img_shape = img.shape + + h, w = img.shape[0:2] + scale = self.short_size * 1.0 / min(h, w) + h = int(h * scale + 0.5) + w = int(w * scale + 0.5) + if h % 32 != 0: + h = h + (32 - h % 32) + if w % 32 != 0: + w = w + (32 - w % 32) + img = cv2.resize(img, dsize=(w, h)) + + new_img_shape = img.shape + img_shape = np.array(org_img_shape + new_img_shape) + + data['shape'] = img_shape + data['image'] = img + + return data \ No newline at end of file diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 59cb9b8a253cf04244ebf83511ab412174487a53..dbfb93176cc782bedc8f7b33367b59046c4abec8 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1395,3 +1395,29 @@ class VLLabelEncode(BaseRecLabelEncode): data['label_res'] = np.array(label_res) data['label_sub'] = np.array(label_sub) return data + + +class CTLabelEncode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = data['label'] + + label = json.loads(label) + nBox = len(label) + boxes, txts = [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + box = np.array(box) + + boxes.append(box) + txt = label[bno]['transcription'] + txts.append(txt) + + if len(boxes) == 0: + return None + + data['polys'] = boxes + data['texts'] = txts + return data \ No newline at end of file diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 1a11778945c9d7b5f5519cd55473e8bf7790db2c..02525b3d50ad87509a6cba6fb2c1b00cb0add56e 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,6 +25,7 @@ from .det_east_loss import EASTLoss from .det_sast_loss import SASTLoss from .det_pse_loss import PSELoss from .det_fce_loss import FCELoss +from .det_ct_loss import CTLoss # 
rec loss
 from .rec_ctc_loss import CTCLoss
@@ -68,7 +69,7 @@ def build_loss(config):
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
         'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss'
+        'SLALoss', 'CTLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/det_ct_loss.py b/ppocr/losses/det_ct_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..f48c95be4f84e2d8520363379b3061fa4245c105
--- /dev/null
+++ b/ppocr/losses/det_ct_loss.py
@@ -0,0 +1,276 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referred from:
+https://github.com/shengtao96/CentripetalText/tree/main/models/loss
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+def ohem_single(score, gt_text, training_mask):
+    # online hard example mining
+
+    pos_num = int(paddle.sum(gt_text > 0.5)) - int(
+        paddle.sum((gt_text > 0.5) & (training_mask <= 0.5)))
+
+    if pos_num == 0:
+        # selected_mask = gt_text.copy() * 0 # may be not good
+        selected_mask = training_mask
+        selected_mask = paddle.cast(
+            selected_mask.reshape(
+                (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+        return selected_mask
+
+    neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5)))
+    neg_num = int(min(pos_num * 3, neg_num))
+
+    if neg_num == 0:
+        selected_mask = training_mask
+        selected_mask = paddle.cast(
+            selected_mask.reshape(
+                (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+        return selected_mask
+
+    # hard example
+    neg_score = score[(gt_text <= 0.5) & (training_mask > 0.5)]
+    neg_score_sorted = paddle.sort(-neg_score)
+    threshold = -neg_score_sorted[neg_num - 1]
+
+    selected_mask = ((score >= threshold) |
+                     (gt_text > 0.5)) & (training_mask > 0.5)
+    selected_mask = paddle.cast(
+        selected_mask.reshape(
+            (1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
+    return selected_mask
+
+
+def ohem_batch(scores, gt_texts, training_masks):
+    selected_masks = []
+    for i in range(scores.shape[0]):
+        selected_masks.append(
+            ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
+                i, :, :]))
+
+    selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32")
+    return selected_masks
+
+
+def iou_single(a, b, mask, n_class):
+    EPS = 1e-6
+    valid = mask == 1
+    a = a[valid]
+    b = b[valid]
+    miou = []
+
+    # iou of each class
+    for i in range(n_class):
+        inter = paddle.cast(((a == i) & (b == i)), "float32")
+        union = paddle.cast(((a == i) | (b == i)), "float32")
+
+        miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
+    miou = sum(miou) / len(miou)
+    return miou
+
+
+def iou(a, b, mask, n_class=2, reduce=True):
+    batch_size = a.shape[0]
+
+    a = a.reshape((batch_size, -1))
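+    # flatten b and mask the same way, then compute IoU sample by sample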
+ b = b.reshape((batch_size, -1)) + mask = mask.reshape((batch_size, -1)) + + iou = paddle.zeros((batch_size, ), dtype="float32") + for i in range(batch_size): + iou[i] = iou_single(a[i], b[i], mask[i], n_class) + + if reduce: + iou = paddle.mean(iou) + return iou + + +class DiceLoss(nn.Layer): + def __init__(self, loss_weight=1.0): + super(DiceLoss, self).__init__() + self.loss_weight = loss_weight + + def forward(self, input, target, mask, reduce=True): + batch_size = input.shape[0] + input = F.sigmoid(input) # scale to 0-1 + + input = input.reshape((batch_size, -1)) + target = paddle.cast(target.reshape((batch_size, -1)), "float32") + mask = paddle.cast(mask.reshape((batch_size, -1)), "float32") + + input = input * mask + target = target * mask + + a = paddle.sum(input * target, axis=1) + b = paddle.sum(input * input, axis=1) + 0.001 + c = paddle.sum(target * target, axis=1) + 0.001 + d = (2 * a) / (b + c) + loss = 1 - d + + loss = self.loss_weight * loss + + if reduce: + loss = paddle.mean(loss) + + return loss + + +class SmoothL1Loss(nn.Layer): + def __init__(self, beta=1.0, loss_weight=1.0): + super(SmoothL1Loss, self).__init__() + self.beta = beta + self.loss_weight = loss_weight + + np_coord = np.zeros(shape=[640, 640, 2], dtype=np.int64) + for i in range(640): + for j in range(640): + np_coord[i, j, 0] = j + np_coord[i, j, 1] = i + np_coord = np_coord.reshape((-1, 2)) + + self.coord = self.create_parameter( + shape=[640 * 640, 2], + dtype="int32", # NOTE: not support "int64" before paddle 2.3.1 + default_initializer=nn.initializer.Assign(value=np_coord)) + self.coord.stop_gradient = True + + def forward_single(self, input, target, mask, beta=1.0, eps=1e-6): + batch_size = input.shape[0] + + diff = paddle.abs(input - target) * mask.unsqueeze(1) + loss = paddle.where(diff < beta, 0.5 * diff * diff / beta, + diff - 0.5 * beta) + loss = paddle.cast(loss.reshape((batch_size, -1)), "float32") + mask = paddle.cast(mask.reshape((batch_size, -1)), "float32") + loss = paddle.sum(loss, axis=-1) + loss = loss / (mask.sum(axis=-1) + eps) + + return loss + + def select_single(self, distance, gt_instance, gt_kernel_instance, + training_mask): + + with paddle.no_grad(): + # paddle 2.3.1, paddle.slice not support: + # distance[:, self.coord[:, 1], self.coord[:, 0]] + select_distance_list = [] + for i in range(2): + tmp1 = distance[i, :] + tmp2 = tmp1[self.coord[:, 1], self.coord[:, 0]] + select_distance_list.append(tmp2.unsqueeze(0)) + select_distance = paddle.concat(select_distance_list, axis=0) + + off_points = paddle.cast( + self.coord, "float32") + 10 * select_distance.transpose((1, 0)) + + off_points = paddle.cast(off_points, "int64") + off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1) + + selected_mask = ( + gt_instance[self.coord[:, 1], self.coord[:, 0]] != + gt_kernel_instance[off_points[:, 1], off_points[:, 0]]) + selected_mask = paddle.cast( + selected_mask.reshape((1, -1, distance.shape[-1])), "int64") + selected_training_mask = selected_mask * training_mask + + return selected_training_mask + + def forward(self, + distances, + gt_instances, + gt_kernel_instances, + training_masks, + gt_distances, + reduce=True): + + selected_training_masks = [] + for i in range(distances.shape[0]): + selected_training_masks.append( + self.select_single(distances[i, :, :, :], gt_instances[i, :, :], + gt_kernel_instances[i, :, :], training_masks[ + i, :, :])) + selected_training_masks = paddle.cast( + paddle.concat(selected_training_masks, 0), "float32") + + loss = 
self.forward_single(distances, gt_distances, + selected_training_masks, self.beta) + loss = self.loss_weight * loss + + with paddle.no_grad(): + batch_size = distances.shape[0] + false_num = selected_training_masks.reshape((batch_size, -1)) + false_num = false_num.sum(axis=-1) + total_num = paddle.cast( + training_masks.reshape((batch_size, -1)), "float32") + total_num = total_num.sum(axis=-1) + iou_text = (total_num - false_num) / (total_num + 1e-6) + + if reduce: + loss = paddle.mean(loss) + + return loss, iou_text + + +class CTLoss(nn.Layer): + def __init__(self): + super(CTLoss, self).__init__() + self.kernel_loss = DiceLoss() + self.loc_loss = SmoothL1Loss(beta=0.1, loss_weight=0.05) + + def forward(self, preds, batch): + imgs = batch[0] + out = preds['maps'] + gt_kernels, training_masks, gt_instances, gt_kernel_instances, training_mask_distances, gt_distances = batch[ + 1:] + + kernels = out[:, 0, :, :] + distances = out[:, 1:, :, :] + + # kernel loss + selected_masks = ohem_batch(kernels, gt_kernels, training_masks) + + loss_kernel = self.kernel_loss( + kernels, gt_kernels, selected_masks, reduce=False) + + iou_kernel = iou(paddle.cast((kernels > 0), "int64"), + gt_kernels, + training_masks, + reduce=False) + losses = dict(loss_kernels=loss_kernel, ) + + # loc loss + loss_loc, iou_text = self.loc_loss( + distances, + gt_instances, + gt_kernel_instances, + training_mask_distances, + gt_distances, + reduce=False) + losses.update(dict(loss_loc=loss_loc, )) + + loss_all = loss_kernel + loss_loc + losses = {'loss': loss_all} + + return losses diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 853647c06cf0519a0e049e14c16a0d3e26f9845b..a39d0a464f3f96b44d23cec55768223ca41311fa 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -31,12 +31,14 @@ from .kie_metric import KIEMetric from .vqa_token_ser_metric import VQASerTokenMetric from .vqa_token_re_metric import VQAReTokenMetric from .sr_metric import SRMetric +from .ct_metric import CTMetric + def build_metric(config): support_dict = [ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', - 'VQAReTokenMetric', 'SRMetric' + 'VQAReTokenMetric', 'SRMetric', 'CTMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/ct_metric.py b/ppocr/metrics/ct_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a7634230a23027a5dd5c32a7b8eb87ee4a229076 --- /dev/null +++ b/ppocr/metrics/ct_metric.py @@ -0,0 +1,52 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
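+#
+# Evaluation flow in brief: each __call__ stores the per-image sigma/tau
+# tables computed by get_score_C, and get_metric() merges them with
+# combine_results, which applies DetEval-style one-to-one, one-to-many and
+# many-to-one matching to produce precision, recall and f_score.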
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from scipy import io
+import numpy as np
+
+from ppocr.utils.e2e_metric.Deteval import combine_results, get_score_C
+
+
+class CTMetric(object):
+    def __init__(self, main_indicator, delimiter='\t', **kwargs):
+        self.delimiter = delimiter
+        self.main_indicator = main_indicator
+        self.reset()
+
+    def reset(self):
+        self.results = []  # clear results
+
+    def __call__(self, preds, batch, **kwargs):
+        # NOTE: only batch_size=1 is supported, as the label lengths of
+        # different samples are unequal
+        assert len(
+            preds) == 1, "CentripetalText test now only supports batch_size=1."
+        label = batch[2]
+        text = batch[3]
+        pred = preds[0]['points']
+        result = get_score_C(label, text, pred)
+
+        self.results.append(result)
+
+    def get_metric(self):
+        """
+        Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n').
+        """
+        metrics = combine_results(self.results, rec_flag=False)
+        self.reset()
+        return metrics
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 0feda6c6e062fa314d97b8949d8545ed3305c22e..751757e5f176119688e2db47a68c514850b91823 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -23,6 +23,7 @@ def build_head(config):
     from .det_pse_head import PSEHead
     from .det_fce_head import FCEHead
     from .e2e_pg_head import PGHead
+    from .det_ct_head import CT_Head
 
     # rec head
     from .rec_ctc_head import CTCHead
@@ -52,7 +53,7 @@ def build_head(config):
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
         'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
-        'VLHead', 'SLAHead', 'RobustScannerHead'
+        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/det_ct_head.py b/ppocr/modeling/heads/det_ct_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..08e6719e8f0ade6887eb4ad7f44a2bc36ec132db
--- /dev/null
+++ b/ppocr/modeling/heads/det_ct_head.py
@@ -0,0 +1,69 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.)
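+
+# CT_Head maps the fused FPN feature to `num_classes` output channels:
+# channel 0 is the text-kernel score map and the remaining channels are the
+# centripetal-shift (offset) maps; CTLoss and CTPostProcess split the output
+# the same way (out[:, 0] vs. out[:, 1:]).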
+ + +class CT_Head(nn.Layer): + def __init__(self, + in_channels, + hidden_dim, + num_classes, + loss_kernel=None, + loss_loc=None): + super(CT_Head, self).__init__() + self.conv1 = nn.Conv2D( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(hidden_dim) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2D( + hidden_dim, num_classes, kernel_size=1, stride=1, padding=0) + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + normal_ = Normal(mean=0.0, std=math.sqrt(2. / n)) + normal_(m.weight) + elif isinstance(m, nn.BatchNorm2D): + zeros_(m.bias) + ones_(m.weight) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def forward(self, f, targets=None): + out = self.conv1(f) + out = self.relu1(self.bn1(out)) + out = self.conv2(out) + + if self.training: + out = self._upsample(out, scale=4) + return {'maps': out} + else: + score = F.sigmoid(out[:, 0, :, :]) + return {'maps': out, 'score': score} diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index e3ae2d6ef27821f592645a4ba945d3feeaa8cf8a..c7e8dd068b4a68e56b066ca8fa629644a8f302c6 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -26,13 +26,15 @@ def build_neck(config): from .fce_fpn import FCEFPN from .pren_fpn import PRENFPN from .csp_pan import CSPPAN + from .ct_fpn import CTFPN support_dict = [ 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', - 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN' + 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN' ] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( support_dict)) + module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/modeling/necks/ct_fpn.py b/ppocr/modeling/necks/ct_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4d25e901b5b3093588571f0412a931eaf6f364 --- /dev/null +++ b/ppocr/modeling/necks/ct_fpn.py @@ -0,0 +1,185 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr +import os +import sys + +import math +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +ones_ = Constant(value=1.) +zeros_ = Constant(value=0.) 
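+
+# CTFPN follows a PAN-style FPEM/FFM scheme: the four backbone stages are
+# reduced to 128 channels, refined by two stacked Feature Pyramid Enhancement
+# Modules (FPEM), fused by element-wise summation, and upsampled so that the
+# concatenated output is a single 4 x 128 = 512 channel map.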
+ +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..'))) + + +class Conv_BN_ReLU(nn.Layer): + def __init__(self, + in_planes, + out_planes, + kernel_size=1, + stride=1, + padding=0): + super(Conv_BN_ReLU, self).__init__() + self.conv = nn.Conv2D( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias_attr=False) + self.bn = nn.BatchNorm2D(out_planes) + self.relu = nn.ReLU() + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + normal_ = Normal(mean=0.0, std=math.sqrt(2. / n)) + normal_(m.weight) + elif isinstance(m, nn.BatchNorm2D): + zeros_(m.bias) + ones_(m.weight) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +class FPEM(nn.Layer): + def __init__(self, in_channels, out_channels): + super(FPEM, self).__init__() + planes = out_channels + self.dwconv3_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv2_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv1_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv2_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes) + + self.dwconv3_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes) + + self.dwconv4_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes) + + def _upsample_add(self, x, y): + return F.upsample(x, scale_factor=2, mode='bilinear') + y + + def forward(self, f1, f2, f3, f4): + # up-down + f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3))) + f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2))) + f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1))) + + # down-up + f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1))) + f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2))) + f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3))) + + return f1, f2, f3, f4 + + +class CTFPN(nn.Layer): + def __init__(self, in_channels, out_channel=128): + super(CTFPN, self).__init__() + self.out_channels = out_channel * 4 + + self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128) + self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128) + self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128) + self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128) + + self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128) + self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def forward(self, f): + # # reduce channel + f1 = self.reduce_layer1(f[0]) # N,64,160,160 --> N, 128, 160, 160 + f2 = self.reduce_layer2(f[1]) # N, 128, 80, 80 --> N, 128, 80, 80 + f3 = self.reduce_layer3(f[2]) # N, 256, 40, 
40 --> N, 128, 40, 40
+        f4 = self.reduce_layer4(f[3])  # N, 512, 20, 20 --> N, 128, 20, 20
+
+        # FPEM
+        f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4)
+        f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1)
+
+        # FFM
+        f1 = f1_1 + f1_2
+        f2 = f2_1 + f2_2
+        f3 = f3_1 + f3_2
+        f4 = f4_1 + f4_2
+
+        f2 = self._upsample(f2, scale=2)
+        f3 = self._upsample(f3, scale=4)
+        f4 = self._upsample(f4, scale=8)
+        ff = paddle.concat((f1, f2, f3, f4), 1)  # N, 512, 160, 160
+        return ff
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 8f41a005f5b90e7edf11fad80b9b7eac89257160..35b7a6800da422264a796da14236ae8a484c30d9 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -35,6 +35,7 @@ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess,
 from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
 from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
 from .picodet_postprocess import PicoDetPostProcess
+from .ct_postprocess import CTPostProcess
 
 
 def build_post_process(config, global_config=None):
@@ -48,7 +49,7 @@ def build_post_process(config, global_config=None):
         'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
         'TableMasterLabelDecode', 'SPINLabelDecode',
         'DistillationSerPostProcess', 'DistillationRePostProcess',
-        'VLLabelDecode', 'PicoDetPostProcess'
+        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
     ]
 
     if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/ct_postprocess.py b/ppocr/postprocess/ct_postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..3ab90be24d65888339698a5abe2ed692ceaab4c7
--- /dev/null
+++ b/ppocr/postprocess/ct_postprocess.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referred from:
+https://github.com/shengtao96/CentripetalText/blob/main/test.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import os.path as osp
+import numpy as np
+import cv2
+import paddle
+import pyclipper
+
+
+class CTPostProcess(object):
+    """
+    The post process for Centripetal Text (CT).
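+    It binarizes the kernel score map, labels its connected components as
+    kernels, assigns every text pixel to a kernel by following the predicted
+    centripetal shift, and converts each instance into a 'rect' or 'poly' box.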
+ """ + + def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs): + self.min_score = min_score + self.min_area = min_area + self.box_type = box_type + + self.coord = np.zeros((2, 300, 300), dtype=np.int32) + for i in range(300): + for j in range(300): + self.coord[0, i, j] = j + self.coord[1, i, j] = i + + def __call__(self, preds, batch): + outs = preds['maps'] + out_scores = preds['score'] + + if isinstance(outs, paddle.Tensor): + outs = outs.numpy() + if isinstance(out_scores, paddle.Tensor): + out_scores = out_scores.numpy() + + batch_size = outs.shape[0] + boxes_batch = [] + for idx in range(batch_size): + bboxes = [] + scores = [] + + img_shape = batch[idx] + + org_img_size = img_shape[:3] + img_shape = img_shape[3:] + img_size = img_shape[:2] + + out = np.expand_dims(outs[idx], axis=0) + outputs = dict() + + score = np.expand_dims(out_scores[idx], axis=0) + + kernel = out[:, 0, :, :] > 0.2 + loc = out[:, 1:, :, :].astype("float32") + + score = score[0].astype(np.float32) + kernel = kernel[0].astype(np.uint8) + loc = loc[0].astype(np.float32) + + label_num, label_kernel = cv2.connectedComponents( + kernel, connectivity=4) + + for i in range(1, label_num): + ind = (label_kernel == i) + if ind.sum( + ) < 10: # pixel number less than 10, treated as background + label_kernel[ind] = 0 + + label = np.zeros_like(label_kernel) + h, w = label_kernel.shape + pixels = self.coord[:, :h, :w].reshape(2, -1) + points = pixels.transpose([1, 0]).astype(np.float32) + + off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T + ).astype(np.int32) + off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1) + off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1) + + label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1], + off_points[:, 0]] + label[label_kernel > 0] = label_kernel[label_kernel > 0] + + score_pocket = [0.0] + for i in range(1, label_num): + ind = (label_kernel == i) + if ind.sum() == 0: + score_pocket.append(0.0) + continue + score_i = np.mean(score[ind]) + score_pocket.append(score_i) + + label_num = np.max(label) + 1 + label = cv2.resize( + label, (img_size[1], img_size[0]), + interpolation=cv2.INTER_NEAREST) + + scale = (float(org_img_size[1]) / float(img_size[1]), + float(org_img_size[0]) / float(img_size[0])) + + for i in range(1, label_num): + ind = (label == i) + points = np.array(np.where(ind)).transpose((1, 0)) + + if points.shape[0] < self.min_area: + continue + + score_i = score_pocket[i] + if score_i < self.min_score: + continue + + if self.box_type == 'rect': + rect = cv2.minAreaRect(points[:, ::-1]) + bbox = cv2.boxPoints(rect) * scale + z = bbox.mean(0) + bbox = z + (bbox - z) * 0.85 + elif self.box_type == 'poly': + binary = np.zeros(label.shape, dtype='uint8') + binary[ind] = 1 + try: + _, contours, _ = cv2.findContours( + binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + except BaseException: + contours, _ = cv2.findContours( + binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + bbox = contours[0] * scale + + bbox = bbox.astype('int32') + bboxes.append(bbox.reshape(-1, 2)) + scores.append(score_i) + + boxes_batch.append({'points': bboxes}) + + return boxes_batch diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 45567a7dd2d82b6c583abd4a4eabef52974be081..6ce56eda2aa9f38fdc712d49ae64945c558b418d 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # 
limitations under the License.
 
+import json
 import numpy as np
 import scipy.io as io
+import Polygon as plg
 
 from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
 
@@ -269,7 +271,124 @@ def get_socre_B(gt_dir, img_id, pred_dict):
     return single_data
 
 
-def combine_results(all_data):
+def get_score_C(gt_label, text, pred_bboxes):
+    """
+    get score for CentripetalText (CT) prediction.
+    """
+
+    def gt_reading_mod(gt_label, text):
+        """This helper packs the ground-truth labels into dicts"""
+        groundtruths = []
+        nbox = len(gt_label)
+        for i in range(nbox):
+            label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
+            groundtruths.append(label)
+
+        return groundtruths
+
+    def get_union(pD, pG):
+        areaA = pD.area()
+        areaB = pG.area()
+        return areaA + areaB - get_intersection(pD, pG)
+
+    def get_intersection(pD, pG):
+        pInt = pD & pG
+        if len(pInt) == 0:
+            return 0
+        return pInt.area()
+
+    def detection_filtering(detections, groundtruths, threshold=0.5):
+        for gt in groundtruths:
+            point_num = gt['points'].shape[1] // 2
+            if gt['transcription'] == '###' and (point_num > 1):
+                gt_p = np.array(gt['points']).reshape(point_num,
+                                                      2).astype('int32')
+                gt_p = plg.Polygon(gt_p)
+
+                for det_id, detection in enumerate(detections):
+                    det_y = detection[0::2]
+                    det_x = detection[1::2]
+
+                    det_p = np.concatenate((np.array(det_x), np.array(det_y)))
+                    det_p = det_p.reshape(2, -1).transpose()
+                    det_p = plg.Polygon(det_p)
+
+                    try:
+                        det_gt_iou = get_intersection(det_p,
+                                                      gt_p) / det_p.area()
+                    except Exception:
+                        # degenerate polygon: log it and skip this detection
+                        print(det_x, det_y, gt_p)
+                        continue
+                    if det_gt_iou > threshold:
+                        detections[det_id] = []
+
+                detections[:] = [item for item in detections if item != []]
+        return detections
+
+    def sigma_calculation(det_p, gt_p):
+        """
+        sigma = inter_area / gt_area
+        """
+        if gt_p.area() == 0.:
+            return 0
+        return get_intersection(det_p, gt_p) / gt_p.area()
+
+    def tau_calculation(det_p, gt_p):
+        """
+        tau = inter_area / det_area
+        """
+        if det_p.area() == 0.:
+            return 0
+        return get_intersection(det_p, gt_p) / det_p.area()
+
+    detections = []
+
+    for item in pred_bboxes:
+        detections.append(item[:, ::-1].reshape(-1))
+
+    groundtruths = gt_reading_mod(gt_label, text)
+
+    detections = detection_filtering(
+        detections,
+        groundtruths)  # filter detections overlapping with don't-care regions
+
+    for idx in range(len(groundtruths) - 1, -1, -1):
+        # NOTE: the source code uses 'orin' to mark '###' (don't-care) regions,
+        # while here we rely on the annotation text, which may cause a slight
+        # drop in f-score, about 0.12
+        if groundtruths[idx]['transcription'] == '###':
+            groundtruths.pop(idx)
+
+    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+    local_tau_table = np.zeros((len(groundtruths), len(detections)))
+
+    for gt_id, gt in enumerate(groundtruths):
+        if len(detections) > 0:
+            for det_id, detection in enumerate(detections):
+                point_num = gt['points'].shape[1] // 2
+
+                gt_p = np.array(gt['points']).reshape(point_num,
+                                                      2).astype('int32')
+                gt_p = plg.Polygon(gt_p)
+
+                det_y = detection[0::2]
+                det_x = detection[1::2]
+
+                det_p = np.concatenate((np.array(det_x), np.array(det_y)))
+
+                det_p = det_p.reshape(2, -1).transpose()
+                det_p = plg.Polygon(det_p)
+
+                local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
+                                                                     gt_p)
+                local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
+
+    data = {}
+    data['sigma'] = local_sigma_table
+    data['global_tau'] = local_tau_table
+    data['global_pred_str'] = ''
+    data['global_gt_str'] = ''
+    return data
+
+
+def combine_results(all_data, rec_flag=True):
     tr = 0.7
     tp = 0.6
     fsc_k = 0.8
@@ -278,6 +397,7 @@ def 
combine_results(all_data): global_tau = [] global_pred_str = [] global_gt_str = [] + for data in all_data: global_sigma.append(data['sigma']) global_tau.append(data['global_tau']) @@ -294,7 +414,7 @@ def combine_results(all_data): def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for gt_id in range(num_gt): gt_matching_qualified_sigma_candidates = np.where( @@ -328,14 +448,15 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ - 0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][matched_det_id[0] + .tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end det_flag[0, matched_det_id] = 1 return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num @@ -343,7 +464,7 @@ def combine_results(all_data): def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for gt_id in range(num_gt): # skip the following if the groundtruth was matched @@ -374,28 +495,30 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end global_accumulative_recall = global_accumulative_recall + fsc_k @@ -409,7 +532,7 @@ def combine_results(all_data): def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for det_id in range(num_det): # skip the following if the detection was matched @@ -440,6 +563,30 @@ def combine_results(all_data): 
gt_flag[0, qualified_sigma_candidates] = 1 det_flag[0, det_id] = 1 # recg start + if rec_flag: + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[ + 0].tolist()[idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower( + ): + hit_str_num += 1 + break + # recg end + elif (np.sum(local_tau_table[qualified_sigma_candidates, + det_id]) >= tp): + det_flag[0, det_id] = 1 + gt_flag[0, qualified_sigma_candidates] = 1 + # recg start + if rec_flag: pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): @@ -454,27 +601,7 @@ def combine_results(all_data): else: if pred_str_cur.lower() == gt_str_cur.lower(): hit_str_num += 1 - break - # recg end - elif (np.sum(local_tau_table[qualified_sigma_candidates, - det_id]) >= tp): - det_flag[0, det_id] = 1 - gt_flag[0, qualified_sigma_candidates] = 1 - # recg start - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if ele_gt_id not in global_gt_str[idy]: - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break + break # recg end global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k @@ -504,7 +631,7 @@ def combine_results(all_data): gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num #######then check for one-to-many case########## @@ -512,14 +639,14 @@ def combine_results(all_data): gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num #######then check for many-to-one case########## local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num try: diff --git a/requirements.txt b/requirements.txt index 2c0741a065dacf1fb637865e8f9796a611876d60..43cd8c1b082768ebad44a5cf58fc31980ebfe891 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ lxml premailer openpyxl attrdict +Polygon3 diff --git a/test_tipc/configs/det_r18_ct/train_infer_python.txt b/test_tipc/configs/det_r18_ct/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..5933fdbeed762a73324fbfb5a4113a390926e7ea --- /dev/null +++ b/test_tipc/configs/det_r18_ct/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:det_r18_ct +python:python3.7 +gpu_list:0|0,1 
diff --git a/test_tipc/configs/det_r18_ct/train_infer_python.txt b/test_tipc/configs/det_r18_ct/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5933fdbeed762a73324fbfb5a4113a390926e7ea
--- /dev/null
+++ b/test_tipc/configs/det_r18_ct/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:det_r18_ct
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./train_data/total_text/test/rgb/
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/det/det_r18_vd_ct.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
+quant_export:null
+fpgm_export:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c configs/det/det_r18_vd_ct.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/det_r18_vd_ct/best_accuracy
+infer_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+--save_log_path:null
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 1185dec59e334802cc17b905b98d512e2498497b..a4ba31928bba4a00a560461392f7011244af5e0c 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -274,6 +274,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
         cd ./train_data/ && tar xf XFUND.tar
         cd ../
     fi
+    if [ ${model_name} == "det_r18_ct" ]; then
+        wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
+        wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ct_tipc/total_text_lite2.tar --no-check-certificate
+        cd ./train_data && tar xf total_text_lite2.tar && ln -s total_text_lite2 total_text && cd ../
+    fi
 elif [ ${MODE} = "whole_train_whole_infer" ];then
     wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
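The `tools/infer/predict_det.py` hunks below wire CT into the standalone detector: the first preprocess op is swapped for `ScaleAlignedShort` with a 640-pixel short side, postprocessing is routed to `CTPostProcess`, and the two network outputs are mapped to `maps` and `score`. A hedged usage sketch, assuming the repo root is importable and a model has been exported to `./inference/det_ct/` as in the docs earlier in this diff:

```python
# Hedged usage sketch, not repo code: exercising the CT branch added to
# tools/infer/predict_det.py below. Paths are assumptions from this diff.
import cv2
import tools.infer.utility as utility
from tools.infer.predict_det import TextDetector

args = utility.parse_args()
args.det_algorithm = "CT"
args.det_model_dir = "./inference/det_ct/"

detector = TextDetector(args)
img = cv2.imread("./doc/imgs_en/img623.jpg")
# CTPostProcess is configured with box_type 'poly', so these come back as
# polygons and go through filter_tag_det_res_only_clip (see the hunk below).
dt_boxes, elapse = detector(img)
print("{} text regions in {:.3f}s".format(len(dt_boxes), elapse))
```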
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 9f5c480d3c55367a02eacb48bed6ae3d38282f05..00fa2e9b7fafd949c59a0eebd43f2f88ae717320 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -127,6 +127,9 @@ class TextDetector(object):
             postprocess_params["beta"] = args.beta
             postprocess_params["fourier_degree"] = args.fourier_degree
             postprocess_params["box_type"] = args.det_fce_box_type
+        elif self.det_algorithm == "CT":
+            pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
+            postprocess_params['name'] = 'CTPostProcess'
         else:
             logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
             sys.exit(0)
@@ -253,6 +256,9 @@ class TextDetector(object):
         elif self.det_algorithm == 'FCE':
             for i, output in enumerate(outputs):
                 preds['level_{}'.format(i)] = output
+        elif self.det_algorithm == "CT":
+            preds['maps'] = outputs[0]
+            preds['score'] = outputs[1]
         else:
             raise NotImplementedError

@@ -260,7 +266,7 @@ class TextDetector(object):
         post_result = self.postprocess_op(preds, shape_list)
         dt_boxes = post_result[0]['points']
         if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
-                self.det_algorithm in ["PSE", "FCE"] and
+                self.det_algorithm in ["PSE", "FCE", "CT"] and
                 self.postprocess_op.box_type == 'poly'):
             dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
         else:
diff --git a/tools/program.py b/tools/program.py
index c91e66fd7f1d42482ee3b5c2c39911cc24d966e1..9117d51b95b343c46982f212d4e5faa069b7b44a 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -625,7 +625,7 @@ def preprocess(is_train=False):
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
         'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
-        'Gestalt', 'SLANet', 'RobustScanner'
+        'Gestalt', 'SLANet', 'RobustScanner', 'CT'
     ]

     if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index d0f200189e34265b3c080ac9e25eb80d29c705b7..970a52624af7b2831d88956f857cd4271086bcca 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
             config['Loss']['ignore_index'] = char_num - 1

     model = build_model(config['Architecture'])
+
     use_sync_bn = config["Global"].get("use_sync_bn", False)
     if use_sync_bn:
         model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -138,7 +139,7 @@ def main(config, device, logger, vdl_writer):

     # build metric
     eval_class = build_metric(config['Metric'])
-
+
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
@@ -146,7 +147,7 @@ def main(config, device, logger, vdl_writer):

     use_amp = config["Global"].get("use_amp", False)
     amp_level = config["Global"].get("amp_level", 'O2')
-    amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
+    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
     if use_amp:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
@@ -161,20 +162,24 @@ def main(config, device, logger, vdl_writer):
             use_dynamic_loss_scaling=use_dynamic_loss_scaling)
         if amp_level == "O2":
             model, optimizer = paddle.amp.decorate(
-                models=model, optimizers=optimizer, level=amp_level, master_weight=True)
+                models=model,
+                optimizers=optimizer,
+                level=amp_level,
+                master_weight=True)
     else:
         scaler = None

     # load pretrain model
     pre_best_model_dict = load_model(config, model, optimizer,
                                      config['Architecture']["model_type"])
-
+
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)

     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
-                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list)
+                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
+                  amp_level, amp_custom_black_list)


 def test_reader(config, device, logger):
diff --git a/train.sh b/train.sh
index 4225470cb9f545b874e5f806af22405895e8f6c7..6fa04ea3febe8982016a35d83f119c0a483e3bb8 100644
--- a/train.sh
+++ b/train.sh
@@ -1,2 +1,2 @@
 # recommended paddle.__version__ == 2.0.0
-python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
+python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
\ No newline at end of file
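Beyond registering CT, `tools/train.py` gains an optional `Global.use_sync_bn` switch: when enabled, BatchNorm layers are converted to `SyncBatchNorm`, so normalization statistics are computed across all GPUs rather than per card, which helps detectors trained with small per-card batches. A minimal sketch of that conversion on a toy model (the conversion itself needs no distributed setup, only training does):

```python
# Minimal sketch of the SyncBatchNorm conversion guarded by
# `if use_sync_bn:` in the tools/train.py hunk above.
import paddle
import paddle.nn as nn

# Toy stand-in for the detector built by build_model().
model = nn.Sequential(
    nn.Conv2D(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2D(8),  # replaced in place by the call below
    nn.ReLU())

model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
print([type(l).__name__ for l in model.sublayers()])
# ['Conv2D', 'SyncBatchNorm', 'ReLU']
```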