From f6532a0e51222c4385dd41a0f9de169f188ac29a Mon Sep 17 00:00:00 2001 From: andyjpaddle <87074272+andyjpaddle@users.noreply.github.com> Date: Tue, 26 Apr 2022 16:19:31 +0800 Subject: [PATCH] add ppocrv3 rec (#6033) * add ppocrv3 rec --- configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml | 131 ++++ .../ch_PP-OCRv3_rec_distillation.yml | 205 ++++++ ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 32 + ppocr/data/imaug/rec_img_aug.py | 58 +- ppocr/data/simple_dataset.py | 10 +- ppocr/losses/__init__.py | 3 +- ppocr/losses/basic_loss.py | 4 +- ppocr/losses/combined_loss.py | 2 + ppocr/losses/distillation_loss.py | 58 +- ppocr/losses/rec_multi_loss.py | 58 ++ ppocr/losses/rec_sar_loss.py | 3 +- ppocr/metrics/rec_metric.py | 12 +- ppocr/modeling/architectures/base_model.py | 6 +- .../architectures/distillation_model.py | 4 +- ppocr/modeling/backbones/__init__.py | 4 +- ppocr/modeling/backbones/rec_mv1_enhance.py | 15 +- ppocr/modeling/backbones/rec_svtrnet.py | 595 ++++++++++++++++++ ppocr/modeling/heads/__init__.py | 4 +- ppocr/modeling/heads/rec_multi_head.py | 73 +++ ppocr/modeling/heads/rec_sar_head.py | 14 +- ppocr/modeling/necks/rnn.py | 113 +++- ppocr/postprocess/__init__.py | 3 +- ppocr/postprocess/rec_postprocess.py | 38 ++ tools/eval.py | 32 +- tools/export_model.py | 34 +- tools/infer/predict_rec.py | 6 +- tools/infer_rec.py | 24 +- tools/program.py | 18 +- tools/train.py | 42 +- 30 files changed, 1548 insertions(+), 55 deletions(-) create mode 100644 configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml create mode 100644 configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml create mode 100644 ppocr/losses/rec_multi_loss.py create mode 100644 ppocr/modeling/backbones/rec_svtrnet.py create mode 100644 ppocr/modeling/heads/rec_multi_head.py diff --git a/configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml b/configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml new file mode 100644 index 00000000..c45a1a3c --- /dev/null +++ b/configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml @@ -0,0 +1,131 @@ +Global: + debug: false + use_gpu: true + epoch_num: 500 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_ppocr_v3 + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: &max_text_length 25 + infer_mode: false + use_space_char: true + distributed: true + save_res_path: ./output/rec/predicts_ppocrv3.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 5 + regularizer: + name: L2 + factor: 3.0e-05 + + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + last_conv_stride: [1, 2] + last_pool_type: avg + Head: + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length + +Loss: + name: MultiLoss + loss_config_list: + - CTCLoss: + - SARLoss: + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + ignore_space: True + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + ext_op_transform_idx: 1 + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecConAug: + prob: 0.5 + ext_data_num: 2 + image_shape: [48, 320, 3] + - RecAug: + - MultiLabelEncode: + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_sar + - length + - valid_ratio + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_workers: 4 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - MultiLabelEncode: + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_sar + - length + - valid_ratio + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 4 diff --git a/configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml new file mode 100644 index 00000000..80ec7c63 --- /dev/null +++ b/configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml @@ -0,0 +1,205 @@ +Global: + debug: false + use_gpu: true + epoch_num: 800 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_ppocr_v3_distillation + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: &max_text_length 25 + infer_mode: false + use_space_char: true + distributed: true + save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs : [700, 800] + values : [0.0005, 0.00005] + warmup_epoch: 5 + regularizer: + name: L2 + factor: 3.0e-05 + + +Architecture: + model_type: &model_type "rec" + name: DistillationModel + algorithm: Distillation + Models: + Teacher: + pretrained: + freeze_params: false + return_all_feats: true + model_type: *model_type + algorithm: SVTR + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + last_conv_stride: [1, 2] + last_pool_type: avg + Head: + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length + Student: + pretrained: + freeze_params: false + return_all_feats: true + model_type: *model_type + algorithm: SVTR + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + last_conv_stride: [1, 2] + last_pool_type: avg + Head: + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 64 + depth: 2 + hidden_dims: 120 + use_guide: True + Head: + fc_decay: 0.00001 + - SARHead: + enc_dim: 512 + max_text_length: *max_text_length +Loss: + name: CombinedLoss + loss_config_list: + - DistillationDMLLoss: + weight: 1.0 + act: "softmax" + use_log: true + model_name_pairs: + - ["Student", "Teacher"] + key: head_out + multi_head: True + dis_head: ctc + name: dml_ctc + - DistillationDMLLoss: + weight: 0.5 + act: "softmax" + use_log: true + model_name_pairs: + - ["Student", "Teacher"] + key: head_out + multi_head: True + dis_head: sar + name: dml_sar + - DistillationDistanceLoss: + weight: 1.0 + mode: "l2" + model_name_pairs: + - ["Student", "Teacher"] + key: backbone_out + - DistillationCTCLoss: + weight: 1.0 + model_name_list: ["Student", "Teacher"] + key: head_out + multi_head: True + - DistillationSARLoss: + weight: 1.0 + model_name_list: ["Student", "Teacher"] + key: head_out + multi_head: True + +PostProcess: + name: DistillationCTCLabelDecode + model_name: ["Student", "Teacher"] + key: head_out + multi_head: True + +Metric: + name: DistillationMetric + base_metric_name: RecMetric + main_indicator: acc + key: "Student" + ignore_space: True + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + ext_op_transform_idx: 1 + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecConAug: + prob: 0.5 + ext_data_num: 2 + image_shape: [48, 320, 3] + - RecAug: + - MultiLabelEncode: + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_sar + - length + - valid_ratio + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_workers: 4 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - MultiLabelEncode: + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_sar + - length + - valid_ratio + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 4 diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 164f1d22..c24886aa 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, \ +from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 6f86be7d..86366d7a 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -22,6 +22,7 @@ import numpy as np import string from shapely.geometry import LineString, Point, Polygon import json +import copy from ppocr.utils.logging import get_logger @@ -1007,3 +1008,34 @@ class VQATokenLabelEncode(object): gt_label.extend([self.label2id_map[("i-" + label).upper()]] * (len(encode_res["input_ids"]) - 1)) return gt_label + + +class MultiLabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(MultiLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path, + use_space_char, **kwargs) + self.sar_encode = SARLabelEncode(max_text_length, character_dict_path, + use_space_char, **kwargs) + + def __call__(self, data): + + data_ctc = copy.deepcopy(data) + data_sar = copy.deepcopy(data) + data_out = dict() + data_out['img_path'] = data.get('img_path', None) + data_out['image'] = data['image'] + ctc = self.ctc_encode.__call__(data_ctc) + sar = self.sar_encode.__call__(data_sar) + if ctc is None or sar is None: + return None + data_out['label_ctc'] = ctc['label'] + data_out['label_sar'] = sar['label'] + data_out['length'] = ctc['length'] + return data_out diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 6f59fef6..960a11be 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -32,6 +32,49 @@ class RecAug(object): return data +class RecConAug(object): + def __init__(self, + prob=0.5, + image_shape=(32, 320, 3), + max_text_length=25, + ext_data_num=1, + **kwargs): + self.ext_data_num = ext_data_num + self.prob = prob + self.max_text_length = max_text_length + self.image_shape = image_shape + self.max_wh_ratio = self.image_shape[1] / self.image_shape[0] + + def merge_ext_data(self, data, ext_data): + ori_w = round(data['image'].shape[1] / data['image'].shape[0] * + self.image_shape[0]) + ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] * + self.image_shape[0]) + data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0])) + ext_data['image'] = cv2.resize(ext_data['image'], + (ext_w, self.image_shape[0])) + data['image'] = np.concatenate( + [data['image'], ext_data['image']], axis=1) + data["label"] += ext_data["label"] + return data + + def __call__(self, data): + rnd_num = random.random() + if rnd_num > self.prob: + return data + for idx, ext_data in enumerate(data["ext_data"]): + if len(data["label"]) + len(ext_data[ + "label"]) > self.max_text_length: + break + concat_ratio = data['image'].shape[1] / data['image'].shape[ + 0] + ext_data['image'].shape[1] / ext_data['image'].shape[0] + if concat_ratio > self.max_wh_ratio: + break + data = self.merge_ext_data(data, ext_data) + data.pop("ext_data") + return data + + class ClsResizeImg(object): def __init__(self, image_shape, **kwargs): self.image_shape = image_shape @@ -98,10 +141,13 @@ class RecResizeImg(object): def __call__(self, data): img = data['image'] if self.infer_mode and self.character_dict_path is not None: - norm_img = resize_norm_img_chinese(img, self.image_shape) + norm_img, valid_ratio = resize_norm_img_chinese(img, + self.image_shape) else: - norm_img = resize_norm_img(img, self.image_shape, self.padding) + norm_img, valid_ratio = resize_norm_img(img, self.image_shape, + self.padding) data['image'] = norm_img + data['valid_ratio'] = valid_ratio return data @@ -220,7 +266,8 @@ def resize_norm_img(img, image_shape, padding=True): resized_image /= 0.5 padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im[:, :, 0:resized_w] = resized_image - return padding_im + valid_ratio = min(1.0, float(resized_w / imgW)) + return padding_im, valid_ratio def resize_norm_img_chinese(img, image_shape): @@ -230,7 +277,7 @@ def resize_norm_img_chinese(img, image_shape): h, w = img.shape[0], img.shape[1] ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, ratio) - imgW = int(32 * max_wh_ratio) + imgW = int(imgH * max_wh_ratio) if math.ceil(imgH * ratio) > imgW: resized_w = imgW else: @@ -246,7 +293,8 @@ def resize_norm_img_chinese(img, image_shape): resized_image /= 0.5 padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im[:, :, 0:resized_w] = resized_image - return padding_im + valid_ratio = min(1.0, float(resized_w / imgW)) + return padding_im, valid_ratio def resize_norm_img_srn(img, image_shape): diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 13f9411e..b5da9b88 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -49,7 +49,8 @@ class SimpleDataSet(Dataset): if self.mode == "train" and self.do_shuffle: self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) - + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", + 2) self.need_reset = True in [x < 1 for x in ratio_list] def get_image_info_list(self, file_list, ratio_list): @@ -87,7 +88,7 @@ class SimpleDataSet(Dataset): if hasattr(op, 'ext_data_num'): ext_data_num = getattr(op, 'ext_data_num') break - load_data_ops = self.ops[:2] + load_data_ops = self.ops[:self.ext_op_transform_idx] ext_data = [] while len(ext_data) < ext_data_num: @@ -108,8 +109,11 @@ class SimpleDataSet(Dataset): data['image'] = img data = transform(data, load_data_ops) - if data is None or data['polys'].shape[1] != 4: + if data is None: continue + if 'polys' in data.keys(): + if data['polys'].shape[1] != 4: + continue ext_data.append(data) return ext_data diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 6505fca7..de8419b7 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -34,6 +34,7 @@ from .rec_nrtr_loss import NRTRLoss from .rec_sar_loss import SARLoss from .rec_aster_loss import AsterLoss from .rec_pren_loss import PRENLoss +from .rec_multi_loss import MultiLoss # cls loss from .cls_loss import ClsLoss @@ -60,7 +61,7 @@ def build_loss(config): 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', - 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss' + 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index b19ce57d..2df96ea2 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -106,8 +106,8 @@ class DMLLoss(nn.Layer): def forward(self, out1, out2): if self.act is not None: - out1 = self.act(out1) - out2 = self.act(out2) + out1 = self.act(out1) + 1e-10 + out2 = self.act(out2) + 1e-10 if self.use_log: # for recognition distillation, log is needed for feature map log_out1 = paddle.log(out1) diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index 72f706e3..f4cdee8f 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -18,8 +18,10 @@ import paddle.nn as nn from .rec_ctc_loss import CTCLoss from .center_loss import CenterLoss from .ace_loss import ACELoss +from .rec_sar_loss import SARLoss from .distillation_loss import DistillationCTCLoss +from .distillation_loss import DistillationSARLoss from .distillation_loss import DistillationDMLLoss from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 06aa7fa8..565b066d 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -18,6 +18,7 @@ import numpy as np import cv2 from .rec_ctc_loss import CTCLoss +from .rec_sar_loss import SARLoss from .basic_loss import DMLLoss from .basic_loss import DistanceLoss from .det_db_loss import DBLoss @@ -46,11 +47,15 @@ class DistillationDMLLoss(DMLLoss): act=None, use_log=False, key=None, + multi_head=False, + dis_head='ctc', maps_name=None, name="dml"): super().__init__(act=act, use_log=use_log) assert isinstance(model_name_pairs, list) self.key = key + self.multi_head = multi_head + self.dis_head = dis_head self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) self.name = name self.maps_name = self._check_maps_name(maps_name) @@ -97,7 +102,11 @@ class DistillationDMLLoss(DMLLoss): out2 = out2[self.key] if self.maps_name is None: - loss = super().forward(out1, out2) + if self.multi_head: + loss = super().forward(out1[self.dis_head], + out2[self.dis_head]) + else: + loss = super().forward(out1, out2) if isinstance(loss, dict): for key in loss: loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], @@ -123,11 +132,16 @@ class DistillationDMLLoss(DMLLoss): class DistillationCTCLoss(CTCLoss): - def __init__(self, model_name_list=[], key=None, name="loss_ctc"): + def __init__(self, + model_name_list=[], + key=None, + multi_head=False, + name="loss_ctc"): super().__init__() self.model_name_list = model_name_list self.key = key self.name = name + self.multi_head = multi_head def forward(self, predicts, batch): loss_dict = dict() @@ -135,7 +149,45 @@ class DistillationCTCLoss(CTCLoss): out = predicts[model_name] if self.key is not None: out = out[self.key] - loss = super().forward(out, batch) + if self.multi_head: + assert 'ctc' in out, 'multi head has multi out' + loss = super().forward(out['ctc'], batch[:2] + batch[3:]) + else: + loss = super().forward(out, batch) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + return loss_dict + + +class DistillationSARLoss(SARLoss): + def __init__(self, + model_name_list=[], + key=None, + multi_head=False, + name="loss_sar", + **kwargs): + ignore_index = kwargs.get('ignore_index', 92) + super().__init__(ignore_index=ignore_index) + self.model_name_list = model_name_list + self.key = key + self.name = name + self.multi_head = multi_head + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + if self.multi_head: + assert 'sar' in out, 'multi head has multi out' + loss = super().forward(out['sar'], batch[:1] + batch[2:]) + else: + loss = super().forward(out, batch) if isinstance(loss, dict): for key in loss: loss_dict["{}_{}_{}".format(self.name, model_name, diff --git a/ppocr/losses/rec_multi_loss.py b/ppocr/losses/rec_multi_loss.py new file mode 100644 index 00000000..09f007af --- /dev/null +++ b/ppocr/losses/rec_multi_loss.py @@ -0,0 +1,58 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + +from .rec_ctc_loss import CTCLoss +from .rec_sar_loss import SARLoss + + +class MultiLoss(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + self.loss_funcs = {} + self.loss_list = kwargs.pop('loss_config_list') + self.weight_1 = kwargs.get('weight_1', 1.0) + self.weight_2 = kwargs.get('weight_2', 1.0) + self.gtc_loss = kwargs.get('gtc_loss', 'sar') + for loss_info in self.loss_list: + for name, param in loss_info.items(): + if param is not None: + kwargs.update(param) + loss = eval(name)(**kwargs) + self.loss_funcs[name] = loss + + def forward(self, predicts, batch): + self.total_loss = {} + total_loss = 0.0 + # batch [image, label_ctc, label_sar, length, valid_ratio] + for name, loss_func in self.loss_funcs.items(): + if name == 'CTCLoss': + loss = loss_func(predicts['ctc'], + batch[:2] + batch[3:])['loss'] * self.weight_1 + elif name == 'SARLoss': + loss = loss_func(predicts['sar'], + batch[:1] + batch[2:])['loss'] * self.weight_2 + else: + raise NotImplementedError( + '{} is not supported in MultiLoss yet'.format(name)) + self.total_loss[name] = loss + total_loss += loss + self.total_loss['loss'] = total_loss + return self.total_loss diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py index c8bd8bb0..a4f83f03 100644 --- a/ppocr/losses/rec_sar_loss.py +++ b/ppocr/losses/rec_sar_loss.py @@ -9,8 +9,9 @@ from paddle import nn class SARLoss(nn.Layer): def __init__(self, **kwargs): super(SARLoss, self).__init__() + ignore_index = kwargs.get('ignore_index', 92) # 6626 self.loss_func = paddle.nn.loss.CrossEntropyLoss( - reduction="mean", ignore_index=92) + reduction="mean", ignore_index=ignore_index) def forward(self, predicts, batch): predict = predicts[:, : diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index b047bbcb..515b9372 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -17,9 +17,14 @@ import string class RecMetric(object): - def __init__(self, main_indicator='acc', is_filter=False, **kwargs): + def __init__(self, + main_indicator='acc', + is_filter=False, + ignore_space=True, + **kwargs): self.main_indicator = main_indicator self.is_filter = is_filter + self.ignore_space = ignore_space self.eps = 1e-5 self.reset() @@ -34,8 +39,9 @@ class RecMetric(object): all_num = 0 norm_edit_dis = 0.0 for (pred, pred_conf), (target, _) in zip(preds, labels): - pred = pred.replace(" ", "") - target = target.replace(" ", "") + if self.ignore_space: + pred = pred.replace(" ", "") + target = target.replace(" ", "") if self.is_filter: pred = self._normalize_text(pred) target = self._normalize_text(target) diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index e622db25..f5b29f94 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -83,7 +83,11 @@ class BaseModel(nn.Layer): y["neck_out"] = x if self.use_head: x = self.head(x, targets=data) - if isinstance(x, dict): + # for multi head, save ctc neck out for udml + if isinstance(x, dict) and 'ctc_neck' in x.keys(): + y["neck_out"] = x["ctc_neck"] + y["head_out"] = x + elif isinstance(x, dict): y.update(x) else: y["head_out"] = x diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index 5e867940..cce8fd31 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -53,8 +53,8 @@ class DistillationModel(nn.Layer): self.model_list.append(self.add_sublayer(key, model)) self.model_name_list.append(key) - def forward(self, x): + def forward(self, x, data=None): result_dict = dict() for idx, model_name in enumerate(self.model_name_list): - result_dict[model_name] = self.model_list[idx](x) + result_dict[model_name] = self.model_list[idx](x, data) return result_dict diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index c89c7c25..072d6e0f 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -31,9 +31,11 @@ def build_backbone(config, model_type): from .rec_resnet_aster import ResNet_ASTER from .rec_micronet import MicroNet from .rec_efficientb3_pren import EfficientNetb3_PREN + from .rec_svtrnet import SVTRNet support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', - "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN' + "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN', + 'SVTRNet' ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py index d8a7f4b5..bb6af5e8 100644 --- a/ppocr/modeling/backbones/rec_mv1_enhance.py +++ b/ppocr/modeling/backbones/rec_mv1_enhance.py @@ -103,7 +103,12 @@ class DepthwiseSeparable(nn.Layer): class MobileNetV1Enhance(nn.Layer): - def __init__(self, in_channels=3, scale=0.5, **kwargs): + def __init__(self, + in_channels=3, + scale=0.5, + last_conv_stride=1, + last_pool_type='max', + **kwargs): super().__init__() self.scale = scale self.block_list = [] @@ -200,7 +205,7 @@ class MobileNetV1Enhance(nn.Layer): num_filters1=1024, num_filters2=1024, num_groups=1024, - stride=1, + stride=last_conv_stride, dw_size=5, padding=2, use_se=True, @@ -208,8 +213,10 @@ class MobileNetV1Enhance(nn.Layer): self.block_list.append(conv6) self.block_list = nn.Sequential(*self.block_list) - - self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + if last_pool_type == 'avg': + self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0) + else: + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.out_channels = int(1024 * scale) def forward(self, inputs): diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py new file mode 100644 index 00000000..bef8f368 --- /dev/null +++ b/ppocr/modeling/backbones/rec_svtrnet.py @@ -0,0 +1,595 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Callable +from paddle import ParamAttr +from paddle.nn.initializer import KaimingNormal +import numpy as np +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal + +trunc_normal_ = TruncatedNormal(std=.02) +normal_ = Normal +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = paddle.to_tensor(1 - drop_prob) + shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype) + random_tensor = paddle.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias_attr=False, + groups=1, + act=nn.GELU): + super().__init__() + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform()), + bias_attr=bias_attr) + self.norm = nn.BatchNorm2D(out_channels) + self.act = act() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMixer(nn.Layer): + def __init__( + self, + dim, + num_heads=8, + HW=[8, 25], + local_k=[3, 3], ): + super().__init__() + self.HW = HW + self.dim = dim + self.local_mixer = nn.Conv2D( + dim, + dim, + local_k, + 1, [local_k[0] // 2, local_k[1] // 2], + groups=num_heads, + weight_attr=ParamAttr(initializer=KaimingNormal())) + + def forward(self, x): + h = self.HW[0] + w = self.HW[1] + x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).transpose([0, 2, 1]) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + mixer='Global', + HW=[8, 25], + local_k=[7, 11], + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.HW = HW + if HW is not None: + H = HW[0] + W = HW[1] + self.N = H * W + self.C = dim + if mixer == 'Local' and HW is not None: + + hk = local_k[0] + wk = local_k[1] + mask = np.ones([H * W, H * W]) + for h in range(H): + for w in range(W): + for kh in range(-(hk // 2), (hk // 2) + 1): + for kw in range(-(wk // 2), (wk // 2) + 1): + if H > (h + kh) >= 0 and W > (w + kw) >= 0: + mask[h * W + w][(h + kh) * W + (w + kw)] = 0 + mask_paddle = paddle.to_tensor(mask, dtype='float32') + mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32') + mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf) + self.mask = mask.unsqueeze([0, 1]) + self.mixer = mixer + + def forward(self, x): + if self.HW is not None: + N = self.N + C = self.C + else: + _, N, C = x.shape + qkv = self.qkv(x).reshape((0, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) + if self.mixer == 'Local': + attn += self.mask + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mixer='Global', + local_mixer=[7, 11], + HW=[8, 25], + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer='nn.LayerNorm', + epsilon=1e-6, + prenorm=True): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm1 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") + if mixer == 'Global' or mixer == 'Local': + self.mixer = Attention( + dim, + num_heads=num_heads, + mixer=mixer, + HW=HW, + local_k=local_mixer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + elif mixer == 'Conv': + self.mixer = ConvMixer( + dim, num_heads=num_heads, HW=HW, local_k=local_mixer) + else: + raise TypeError("The mixer must be one of [Global, Local, Conv]") + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm2 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.prenorm = prenorm + + def forward(self, x): + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, + img_size=[32, 100], + in_channels=3, + embed_dim=768, + sub_num=2): + super().__init__() + num_patches = (img_size[1] // (2 ** sub_num)) * \ + (img_size[0] // (2 ** sub_num)) + self.img_size = img_size + self.num_patches = num_patches + self.embed_dim = embed_dim + self.norm = None + if sub_num == 2: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels, + embed_dim // 2, + 3, + 2, + 1, + act=nn.GELU, + bias_attr=None), + ConvBNLayer( + embed_dim // 2, + embed_dim, + 3, + 2, + 1, + act=nn.GELU, + bias_attr=None)) + if sub_num == 3: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels, + embed_dim // 4, + 3, + 2, + 1, + act=nn.GELU, + bias_attr=None), + ConvBNLayer( + embed_dim // 4, + embed_dim // 2, + 3, + 2, + 1, + act=nn.GELU, + bias_attr=None), + ConvBNLayer( + embed_dim // 2, + embed_dim, + 3, + 2, + 1, + act=nn.GELU, + bias_attr=None), ) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose((0, 2, 1)) + return x + + +class SubSample(nn.Layer): + def __init__(self, + in_channels, + out_channels, + types='Pool', + stride=[2, 1], + sub_norm='nn.LayerNorm', + act=None): + super().__init__() + self.types = types + if types == 'Pool': + self.avgpool = nn.AvgPool2D( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.maxpool = nn.MaxPool2D( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=ParamAttr(initializer=KaimingNormal())) + self.norm = eval(sub_norm)(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x): + + if self.types == 'Pool': + x1 = self.avgpool(x) + x2 = self.maxpool(x) + x = (x1 + x2) * 0.5 + out = self.proj(x.flatten(2).transpose((0, 2, 1))) + else: + x = self.conv(x) + out = x.flatten(2).transpose((0, 2, 1)) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + + return out + + +class SVTRNet(nn.Layer): + def __init__( + self, + img_size=[32, 100], + in_channels=3, + embed_dim=[64, 128, 256], + depth=[3, 6, 3], + num_heads=[2, 4, 8], + mixer=['Local'] * 6 + ['Global'] * + 6, # Local atten, Global atten, Conv + local_mixer=[[7, 11], [7, 11], [7, 11]], + patch_merging='Conv', # Conv, Pool, None + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + last_drop=0.1, + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer='nn.LayerNorm', + sub_norm='nn.LayerNorm', + epsilon=1e-6, + out_channels=192, + out_char_num=25, + block_unit='Block', + act='nn.GELU', + last_stage=True, + sub_num=2, + prenorm=True, + use_lenhead=False, + **kwargs): + super().__init__() + self.img_size = img_size + self.embed_dim = embed_dim + self.out_channels = out_channels + self.prenorm = prenorm + patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging + self.patch_embed = PatchEmbed( + img_size=img_size, + in_channels=in_channels, + embed_dim=embed_dim[0], + sub_num=sub_num) + num_patches = self.patch_embed.num_patches + self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] + self.pos_embed = self.create_parameter( + shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_) + self.add_parameter("pos_embed", self.pos_embed) + self.pos_drop = nn.Dropout(p=drop_rate) + Block_unit = eval(block_unit) + + dpr = np.linspace(0, drop_path_rate, sum(depth)) + self.blocks1 = nn.LayerList([ + Block_unit( + dim=embed_dim[0], + num_heads=num_heads[0], + mixer=mixer[0:depth[0]][i], + HW=self.HW, + local_mixer=local_mixer[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.Swish, + attn_drop=attn_drop_rate, + drop_path=dpr[0:depth[0]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[0]) + ]) + if patch_merging is not None: + self.sub_sample1 = SubSample( + embed_dim[0], + embed_dim[1], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 2, self.HW[1]] + else: + HW = self.HW + self.patch_merging = patch_merging + self.blocks2 = nn.LayerList([ + Block_unit( + dim=embed_dim[1], + num_heads=num_heads[1], + mixer=mixer[depth[0]:depth[0] + depth[1]][i], + HW=HW, + local_mixer=local_mixer[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0]:depth[0] + depth[1]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[1]) + ]) + if patch_merging is not None: + self.sub_sample2 = SubSample( + embed_dim[1], + embed_dim[2], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 4, self.HW[1]] + else: + HW = self.HW + self.blocks3 = nn.LayerList([ + Block_unit( + dim=embed_dim[2], + num_heads=num_heads[2], + mixer=mixer[depth[0] + depth[1]:][i], + HW=HW, + local_mixer=local_mixer[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1]:][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[2]) + ]) + self.last_stage = last_stage + if last_stage: + self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num]) + self.last_conv = nn.Conv2D( + in_channels=embed_dim[2], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.hardswish = nn.Hardswish() + self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer") + if not prenorm: + self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon) + self.use_lenhead = use_lenhead + if use_lenhead: + self.len_conv = nn.Linear(embed_dim[2], self.out_channels) + self.hardswish_len = nn.Hardswish() + self.dropout_len = nn.Dropout( + p=last_drop, mode="downscale_in_infer") + + trunc_normal_(self.pos_embed) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample1( + x.transpose([0, 2, 1]).reshape( + [0, self.embed_dim[0], self.HW[0], self.HW[1]])) + for blk in self.blocks2: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample2( + x.transpose([0, 2, 1]).reshape( + [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) + for blk in self.blocks3: + x = blk(x) + if not self.prenorm: + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.use_lenhead: + len_x = self.len_conv(x.mean(1)) + len_x = self.dropout_len(self.hardswish_len(len_x)) + if self.last_stage: + if self.patch_merging is not None: + h = self.HW[0] // 4 + else: + h = self.HW[0] + x = self.avg_pool( + x.transpose([0, 2, 1]).reshape( + [0, self.embed_dim[2], h, self.HW[1]])) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + if self.use_lenhead: + return x, len_x + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index b13fe2ec..1670ea38 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -32,6 +32,7 @@ def build_head(config): from .rec_sar_head import SARHead from .rec_aster_head import AsterHead from .rec_pren_head import PRENHead + from .rec_multi_head import MultiHead # cls head from .cls_head import ClsHead @@ -44,7 +45,8 @@ def build_head(config): support_dict = [ 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', - 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead' + 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', + 'MultiHead' ] #table head diff --git a/ppocr/modeling/heads/rec_multi_head.py b/ppocr/modeling/heads/rec_multi_head.py new file mode 100644 index 00000000..2f10e7bd --- /dev/null +++ b/ppocr/modeling/heads/rec_multi_head.py @@ -0,0 +1,73 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR +from .rec_ctc_head import CTCHead +from .rec_sar_head import SARHead + + +class MultiHead(nn.Layer): + def __init__(self, in_channels, out_channels_list, **kwargs): + super().__init__() + self.head_list = kwargs.pop('head_list') + self.gtc_head = 'sar' + assert len(self.head_list) >= 2 + for idx, head_name in enumerate(self.head_list): + name = list(head_name)[0] + if name == 'SARHead': + # sar head + sar_args = self.head_list[idx][name] + self.sar_head = eval(name)(in_channels=in_channels, \ + out_channels=out_channels_list['SARLabelDecode'], **sar_args) + elif name == 'CTCHead': + # ctc neck + self.encoder_reshape = Im2Seq(in_channels) + neck_args = self.head_list[idx][name]['Neck'] + encoder_type = neck_args.pop('name') + self.encoder = encoder_type + self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \ + encoder_type=encoder_type, **neck_args) + # ctc head + head_args = self.head_list[idx][name]['Head'] + self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \ + out_channels=out_channels_list['CTCLabelDecode'], **head_args) + else: + raise NotImplementedError( + '{} is not supported in MultiHead yet'.format(name)) + + def forward(self, x, targets=None): + ctc_encoder = self.ctc_encoder(x) + ctc_out = self.ctc_head(ctc_encoder, targets) + head_out = dict() + head_out['ctc'] = ctc_out + head_out['ctc_neck'] = ctc_encoder + # eval mode + if not self.training: + return ctc_out + if self.gtc_head == 'sar': + sar_out = self.sar_head(x, targets[1:]) + head_out['sar'] = sar_out + return head_out + else: + return head_out diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py index 3b767426..27693ebc 100644 --- a/ppocr/modeling/heads/rec_sar_head.py +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder): class SARHead(nn.Layer): def __init__(self, + in_channels, out_channels, + enc_dim=512, + max_text_length=30, enc_bi_rnn=False, enc_drop_rnn=0.1, enc_gru=False, @@ -358,14 +361,17 @@ class SARHead(nn.Layer): dec_gru=False, d_k=512, pred_dropout=0.1, - max_text_length=30, pred_concat=True, **kwargs): super(SARHead, self).__init__() # encoder module self.encoder = SAREncoder( - enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru) + enc_bi_rnn=enc_bi_rnn, + enc_drop_rnn=enc_drop_rnn, + enc_gru=enc_gru, + d_model=in_channels, + d_enc=enc_dim) # decoder module self.decoder = ParallelSARDecoder( @@ -374,6 +380,8 @@ class SARHead(nn.Layer): dec_bi_rnn=dec_bi_rnn, dec_drop_rnn=dec_drop_rnn, dec_gru=dec_gru, + d_model=in_channels, + d_enc=enc_dim, d_k=d_k, pred_dropout=pred_dropout, max_text_length=max_text_length, @@ -390,7 +398,7 @@ class SARHead(nn.Layer): label = paddle.to_tensor(label, dtype='int64') final_out = self.decoder( feat, holistic_feat, label, img_metas=targets) - if not self.training: + else: final_out = self.decoder( feat, holistic_feat, diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index 86e64902..c8a774b8 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -16,9 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import paddle from paddle import nn from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr +from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_ class Im2Seq(nn.Layer): @@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer): return x +class EncoderWithSVTR(nn.Layer): + def __init__( + self, + in_channels, + dims=64, # XS + depth=2, + hidden_dims=120, + use_guide=False, + num_heads=8, + qkv_bias=True, + mlp_ratio=2.0, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path=0., + qk_scale=None): + super(EncoderWithSVTR, self).__init__() + self.depth = depth + self.use_guide = use_guide + self.conv1 = ConvBNLayer( + in_channels, in_channels // 8, padding=1, act=nn.Swish) + self.conv2 = ConvBNLayer( + in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish) + + self.svtr_block = nn.LayerList([ + Block( + dim=hidden_dims, + num_heads=num_heads, + mixer='Global', + HW=None, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.Swish, + attn_drop=attn_drop_rate, + drop_path=drop_path, + norm_layer='nn.LayerNorm', + epsilon=1e-05, + prenorm=False) for i in range(depth) + ]) + self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6) + self.conv3 = ConvBNLayer( + hidden_dims, in_channels, kernel_size=1, act=nn.Swish) + # last conv-nxn, the input is concat of input tensor and conv3 output tensor + self.conv4 = ConvBNLayer( + 2 * in_channels, in_channels // 8, padding=1, act=nn.Swish) + + self.conv1x1 = ConvBNLayer( + in_channels // 8, dims, kernel_size=1, act=nn.Swish) + self.out_channels = dims + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward(self, x): + # for use guide + if self.use_guide: + z = x.clone() + z.stop_gradient = True + else: + z = x + # for short cut + h = z + # reduce dim + z = self.conv1(z) + z = self.conv2(z) + # SVTR global block + B, C, H, W = z.shape + z = z.flatten(2).transpose([0, 2, 1]) + for blk in self.svtr_block: + z = blk(z) + z = self.norm(z) + # last stage + z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2]) + z = self.conv3(z) + z = paddle.concat((h, z), axis=1) + z = self.conv1x1(self.conv4(z)) + return z + + class SequenceEncoder(nn.Layer): def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): super(SequenceEncoder, self).__init__() self.encoder_reshape = Im2Seq(in_channels) self.out_channels = self.encoder_reshape.out_channels + self.encoder_type = encoder_type if encoder_type == 'reshape': self.only_reshape = True else: support_encoder_dict = { 'reshape': Im2Seq, 'fc': EncoderWithFC, - 'rnn': EncoderWithRNN + 'rnn': EncoderWithRNN, + 'svtr': EncoderWithSVTR } assert encoder_type in support_encoder_dict, '{} must in {}'.format( encoder_type, support_encoder_dict.keys()) - - self.encoder = support_encoder_dict[encoder_type]( - self.encoder_reshape.out_channels, hidden_size) + if encoder_type == "svtr": + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels, **kwargs) + else: + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels, hidden_size) self.out_channels = self.encoder.out_channels self.only_reshape = False def forward(self, x): - x = self.encoder_reshape(x) - if not self.only_reshape: + if self.encoder_type != 'svtr': + x = self.encoder_reshape(x) + if not self.only_reshape: + x = self.encoder(x) + return x + else: x = self.encoder(x) - return x + x = self.encoder_reshape(x) + return x diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 14be63dd..f50b5f1c 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -41,7 +41,8 @@ def build_post_process(config, global_config=None): 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', - 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode' + 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', + 'DistillationSARLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 47825dc7..bf0fd890 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode): use_space_char=False, model_name=["student"], key=None, + multi_head=False, **kwargs): super(DistillationCTCLabelDecode, self).__init__(character_dict_path, use_space_char) @@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode): self.model_name = model_name self.key = key + self.multi_head = multi_head def __call__(self, preds, label=None, *args, **kwargs): output = dict() @@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode): pred = preds[name] if self.key is not None: pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred['ctc'] output[name] = super().__call__(pred, label=label, *args, **kwargs) return output @@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode): return [self.padding_idx] +class DistillationSARLabelDecode(SARLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs): + super(DistillationSARLabelDecode, self).__init__(character_dict_path, + use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred['sar'] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + class PRENLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/tools/eval.py b/tools/eval.py index f6fcf14c..1038090a 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -47,14 +47,38 @@ def main(): if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: - config['Architecture']["Models"][key]["Head"][ - 'out_channels'] = char_num + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + out_channels_list = {} + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head + out_channels_list = {} + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - extra_input = config['Architecture'][ - 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + if config['Architecture']['algorithm'] == 'Distillation': + extra_input = config['Architecture']['Models']['Teacher'][ + 'algorithm'] in extra_input_models + else: + extra_input = config['Architecture']['algorithm'] in extra_input_models if "model_type" in config['Architecture'].keys(): model_type = config['Architecture']['model_type'] else: diff --git a/tools/export_model.py b/tools/export_model.py index bd647fc7..96cc05a2 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -55,6 +55,13 @@ def export_single_model(model, arch_config, save_path, logger): shape=[None, 3, 48, 160], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SVTR": + if arch_config["Head"]["name"] == 'MultiHead': + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, -1], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "PREN": other_shape = [ paddle.static.InputSpec( @@ -105,13 +112,36 @@ def main(): if config["Architecture"]["algorithm"] in ["Distillation", ]: # distillation model for key in config["Architecture"]["Models"]: - config["Architecture"]["Models"][key]["Head"][ - "out_channels"] = char_num + if config["Architecture"]["Models"][key]["Head"][ + "name"] == 'MultiHead': # multi head + out_channels_list = {} + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + loss_list = config['Loss']['loss_config_list'] + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config["Architecture"]["Models"][key]["Head"][ + "out_channels"] = char_num # just one final tensor needs to to exported for inference config["Architecture"]["Models"][key][ "return_all_feats"] = False + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # multi head + out_channels_list = {} + char_num = len(getattr(post_process_class, 'character')) + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list else: # base rec model config["Architecture"]["Head"]["out_channels"] = char_num + model = build_model(config["Architecture"]) load_model(config, model) model.eval() diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index c5aacb06..d4fbc388 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -107,7 +107,7 @@ class TextRecognizer(object): return norm_img.astype(np.float32) / 128. - 1. assert imgC == img.shape[2] - imgW = int((32 * max_wh_ratio)) + imgW = int((imgH * max_wh_ratio)) if self.use_onnx: w = self.input_tensor.shape[3:][0] if w is not None and w > 0: @@ -255,7 +255,9 @@ class TextRecognizer(object): for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] - max_wh_ratio = 0 + imgC, imgH, imgW = self.rec_image_shape + max_wh_ratio = imgW / imgH + # max_wh_ratio = 0 for ino in range(beg_img_no, end_img_no): h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 02b3afd8..63d410b6 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -51,8 +51,28 @@ def main(): if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: - config['Architecture']["Models"][key]["Head"][ - 'out_channels'] = char_num + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + out_channels_list = {} + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head loss + out_channels_list = {} + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num diff --git a/tools/program.py b/tools/program.py index 8ec152bb..1742f6c9 100755 --- a/tools/program.py +++ b/tools/program.py @@ -201,12 +201,17 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - extra_input = config['Architecture'][ - 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + if config['Architecture']['algorithm'] == 'Distillation': + extra_input = config['Architecture']['Models']['Teacher'][ + 'algorithm'] in extra_input_models + else: + extra_input = config['Architecture']['algorithm'] in extra_input_models try: model_type = config['Architecture']['model_type'] except: model_type = None + algorithm = config['Architecture']['algorithm'] start_epoch = best_model_dict[ @@ -269,7 +274,12 @@ def train(config, if model_type in ['table', 'kie']: eval_class(preds, batch) else: - post_result = post_process_class(preds, batch[1]) + if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' + ]: # for multi head loss + post_result = post_process_class( + preds['ctc'], batch[1]) # for CTC head out + else: + post_result = post_process_class(preds, batch[1]) eval_class(post_result, batch) metric = eval_class.get_metric() train_stats.update(metric) @@ -541,7 +551,7 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' ] device = 'cpu' diff --git a/tools/train.py b/tools/train.py index f6cd0e7d..77e600ab 100755 --- a/tools/train.py +++ b/tools/train.py @@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer): if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: - config['Architecture']["Models"][key]["Head"][ - 'out_channels'] = char_num + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][-1].keys())[ + 0] == 'DistillationSARLoss' + config['Loss']['loss_config_list'][-1][ + 'DistillationSARLoss']['ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][1].keys())[ + 0] == 'SARLoss' + if config['Loss']['loss_config_list'][1]['SARLoss'] is None: + config['Loss']['loss_config_list'][1]['SARLoss'] = { + 'ignore_index': char_num + 1 + } + else: + config['Loss']['loss_config_list'][1]['SARLoss'][ + 'ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num + if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model + config['Loss']['ignore_index'] = char_num - 1 + model = build_model(config['Architecture']) if config['Global']['distributed']: model = paddle.DataParallel(model) -- GitLab