From 4cddec730733e5ce5b440f05d04f5219f1e045a7 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 27 Apr 2022 18:46:48 +0800 Subject: [PATCH] add rotnet code (#6065) * add rotnet code * add config * fix infer for ssl * rm unused code --- .../cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml | 99 +++++++++++++++++++ deploy/slim/quantization/export_model.py | 68 +++++++++---- deploy/slim/quantization/quant.py | 42 +++++++- doc/doc_ch/layout_datasets.md | 2 +- ppocr/data/__init__.py | 1 + ppocr/data/collate_fn.py | 14 +++ ppocr/data/imaug/__init__.py | 1 + ppocr/data/imaug/ssl_img_aug.py | 60 +++++++++++ ppocr/postprocess/cls_postprocess.py | 15 ++- tools/export_model.py | 8 +- tools/infer_cls.py | 2 + 11 files changed, 283 insertions(+), 29 deletions(-) create mode 100644 configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml create mode 100644 ppocr/data/imaug/ssl_img_aug.py diff --git a/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml b/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml new file mode 100644 index 00000000..1ffeba07 --- /dev/null +++ b/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml @@ -0,0 +1,99 @@ +Global: + debug: false + use_gpu: true + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_ppocr_v3_rotnet + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: null + checkpoints: null + save_inference_dir: null + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 25 + infer_mode: false + use_space_char: true + save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + regularizer: + name: L2 + factor: 1.0e-05 +Architecture: + model_type: cls + algorithm: CLS + Transform: null + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + last_conv_stride: [1, 2] + last_pool_type: avg + Neck: + Head: + name: 
ClsHead + class_dim: 4 + +Loss: + name: ClsLoss + main_indicator: acc + +PostProcess: + name: ClsPostProcess + +Metric: + name: ClsMetric + main_indicator: acc + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecAug: + use_tia: False + - RandAugment: + - SSLRotateResize: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: ["image", "label"] + loader: + collate_fn: "SSLRotateCollate" + shuffle: true + batch_size_per_card: 32 + drop_last: true + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - SSLRotateResize: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: ["image", "label"] + loader: + collate_fn: "SSLRotateCollate" + shuffle: false + drop_last: false + batch_size_per_card: 64 + num_workers: 8 +profiler_options: null diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 822fd5da..90f79dab 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -35,17 +35,7 @@ from ppocr.metrics import build_metric import tools.program as program from paddleslim.dygraph.quant import QAT from ppocr.data import build_dataloader - - -def export_single_model(quanter, model, infer_shape, save_path, logger): - quanter.save_quantized_model( - model, - save_path, - input_spec=[ - paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') - ]) - logger.info('inference QAT model is saved to {}'.format(save_path)) +from tools.export_model import export_single_model def main(): @@ -84,17 +74,54 @@ def main(): config['Global']) # build model - # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) if 
config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: - config['Architecture']["Models"][key]["Head"][ - 'out_channels'] = char_num + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][-1].keys())[ + 0] == 'DistillationSARLoss' + config['Loss']['loss_config_list'][-1][ + 'DistillationSARLoss']['ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][1].keys())[ + 0] == 'SARLoss' + if config['Loss']['loss_config_list'][1]['SARLoss'] is None: + config['Loss']['loss_config_list'][1]['SARLoss'] = { + 'ignore_index': char_num + 1 + } + else: + config['Loss']['loss_config_list'][1]['SARLoss'][ + 'ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num + if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model + config['Loss']['ignore_index'] = char_num - 1 + model = build_model(config['Architecture']) # get QAT model @@ -120,21 +147,22 @@ def main(): for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) - 
infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640] - save_path = config["Global"]["save_inference_dir"] arch_config = config["Architecture"] + + arch_config = config["Architecture"] + if arch_config["algorithm"] in ["Distillation", ]: # distillation model + archs = list(arch_config["Models"].values()) for idx, name in enumerate(model.model_name_list): model.model_list[idx].eval() sub_model_save_path = os.path.join(save_path, name, "inference") - export_single_model(quanter, model.model_list[idx], infer_shape, - sub_model_save_path, logger) + export_single_model(model.model_list[idx], archs[idx], + sub_model_save_path, logger, quanter) else: save_path = os.path.join(save_path, "inference") - model.eval() - export_single_model(quanter, model, infer_shape, save_path, logger) + export_single_model(model, arch_config, save_path, logger, quanter) if __name__ == "__main__": diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 355ba77f..f7acb185 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer): if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: - config['Architecture']["Models"][key]["Head"][ - 'out_channels'] = char_num + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][-1].keys())[ + 0] == 'DistillationSARLoss' + config['Loss']['loss_config_list'][-1][ + 'DistillationSARLoss']['ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + 
else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][1].keys())[ + 0] == 'SARLoss' + if config['Loss']['loss_config_list'][1]['SARLoss'] is None: + config['Loss']['loss_config_list'][1]['SARLoss'] = { + 'ignore_index': char_num + 1 + } + else: + config['Loss']['loss_config_list'][1]['SARLoss'][ + 'ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num + + if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model + config['Loss']['ignore_index'] = char_num - 1 model = build_model(config['Architecture']) pre_best_model_dict = dict() diff --git a/doc/doc_ch/layout_datasets.md b/doc/doc_ch/layout_datasets.md index 45ac3a11..e7055b4e 100644 --- a/doc/doc_ch/layout_datasets.md +++ b/doc/doc_ch/layout_datasets.md @@ -27,7 +27,7 @@ #### 2、CDLA数据集 - **数据来源**:https://github.com/buptlihang/CDLA -- **数据简介**:publaynet数据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。 +- **数据简介**:CDLA数据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 60ab7bd0..78c32796 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None): use_shared_memory = loader_config['use_shared_memory'] else: use_shared_memory = True + if mode == "Train": # Distribute data to multiple cards batch_sampler = DistributedBatchSampler( diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py index 89c6b4fd..0da6060f 100644 --- a/ppocr/data/collate_fn.py +++ b/ppocr/data/collate_fn.py @@ -56,3 +56,17 @@ class ListCollator(object): for idx in to_tensor_idxs: data_dict[idx] = paddle.to_tensor(data_dict[idx]) return list(data_dict.values()) + + +class SSLRotateCollate(object): + """ + batch: [ + [(4*3xH*W), (4,)] + [(4*3xH*W), (4,)] + ... + ] + """ + + def __call__(self, batch): + output = [np.concatenate(d, axis=0) for d in zip(*batch)] + return output diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 7580e607..20aaf48e 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -24,6 +24,7 @@ from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg +from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste from .ColorJitter import ColorJitter diff --git a/ppocr/data/imaug/ssl_img_aug.py b/ppocr/data/imaug/ssl_img_aug.py new file mode 100644 index 00000000..f9ed6ac3 --- /dev/null +++ b/ppocr/data/imaug/ssl_img_aug.py @@ -0,0 +1,60 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import cv2 +import numpy as np +import random +from PIL import Image + +from .rec_img_aug import resize_norm_img + + +class SSLRotateResize(object): + def __init__(self, + image_shape, + padding=False, + select_all=True, + mode="train", + **kwargs): + self.image_shape = image_shape + self.padding = padding + self.select_all = select_all + self.mode = mode + + def __call__(self, data): + img = data["image"] + + data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + data["image_r180"] = cv2.rotate(data["image_r90"], + cv2.ROTATE_90_CLOCKWISE) + data["image_r270"] = cv2.rotate(data["image_r180"], + cv2.ROTATE_90_CLOCKWISE) + + images = [] + for key in ["image", "image_r90", "image_r180", "image_r270"]: + images.append( + resize_norm_img( + data.pop(key), + image_shape=self.image_shape, + padding=self.padding)[0]) + data["image"] = np.stack(images, axis=0) + data["label"] = np.array(list(range(4))) + if not self.select_all: + data["image"] = data["image"][0::2] # just choose 0 and 180 + data["label"] = data["label"][0:2] # label needs to be continuous + if self.mode == "test": + data["image"] = data["image"][0] + data["label"] = data["label"][0] + return data diff --git a/ppocr/postprocess/cls_postprocess.py b/ppocr/postprocess/cls_postprocess.py index 77e7f46d..9a27ba08 100644 --- a/ppocr/postprocess/cls_postprocess.py +++ b/ppocr/postprocess/cls_postprocess.py @@ -17,17 +17,26 @@ import paddle class ClsPostProcess(object): """ Convert between text-label and text-index """ - def __init__(self, label_list, **kwargs): + def __init__(self, 
label_list=None, key=None, **kwargs): super(ClsPostProcess, self).__init__() self.label_list = label_list + self.key = key def __call__(self, preds, label=None, *args, **kwargs): + if self.key is not None: + preds = preds[self.key] + + label_list = self.label_list + if label_list is None: + label_list = {idx: idx for idx in range(preds.shape[-1])} + if isinstance(preds, paddle.Tensor): preds = preds.numpy() + pred_idxs = preds.argmax(axis=1) - decode_out = [(self.label_list[idx], preds[i, idx]) + decode_out = [(label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)] if label is None: return decode_out - label = [(self.label_list[idx], 1.0) for idx in label] + label = [(label_list[idx], 1.0) for idx in label] return decode_out, label diff --git a/tools/export_model.py b/tools/export_model.py index 003bc61f..1f9f29e3 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def export_single_model(model, arch_config, save_path, logger): +def export_single_model(model, arch_config, save_path, logger, quanter=None): if arch_config["algorithm"] == "SRN": max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ @@ -95,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger): shape=[None] + infer_shape, dtype="float32") ]) - paddle.jit.save(model, save_path) + if quanter is None: + paddle.jit.save(model, save_path) + else: + quanter.save_quantized_model(model, save_path) logger.info("inference model is saved to {}".format(save_path)) return @@ -125,7 +128,6 @@ def main(): char_num = char_num - 2 out_channels_list['CTCLabelDecode'] = char_num out_channels_list['SARLabelDecode'] = char_num + 2 - loss_list = config['Loss']['loss_config_list'] config['Architecture']['Models'][key]['Head'][ 'out_channels_list'] = out_channels_list else: diff --git a/tools/infer_cls.py b/tools/infer_cls.py index 
4be30bbb..7fd6b536 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -57,6 +57,8 @@ def main(): continue elif op_name == 'KeepKeys': op[op_name]['keep_keys'] = ['image'] + elif op_name == "SSLRotateResize": + op[op_name]["mode"] = "test" transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) -- GitLab