From 6e607a0fa1cefbf0388dac86c84debf4781cec48 Mon Sep 17 00:00:00 2001 From: OneYearIsEnough <81819512+OneYearIsEnough@users.noreply.github.com> Date: Mon, 28 Feb 2022 21:48:00 +0800 Subject: [PATCH] [Feature] Add PREN Scene Text Recognition Model(Accepted in CVPR2021) (#5563) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Feature] add PREN scene text recognition model * [Patch] Optimize yml File * [Patch] Save Label/Pred Preprocess Time Cost * [BugFix] Modify Shape Conversion to Fit for Inference Model Exportion * [Patch] ? * [Patch] ? * 啥情况... --- configs/rec/rec_efficientb3_fpn_pren.yml | 92 +++++++ ppocr/data/imaug/__init__.py | 3 +- ppocr/data/imaug/label_ops.py | 47 ++++ ppocr/data/imaug/rec_img_aug.py | 19 ++ ppocr/losses/__init__.py | 3 +- ppocr/losses/rec_pren_loss.py | 30 +++ ppocr/modeling/backbones/__init__.py | 3 +- .../backbones/rec_efficientb3_pren.py | 228 ++++++++++++++++++ ppocr/modeling/heads/__init__.py | 3 +- ppocr/modeling/heads/rec_pren_head.py | 34 +++ ppocr/modeling/necks/__init__.py | 6 +- ppocr/modeling/necks/pren_fpn.py | 163 +++++++++++++ ppocr/postprocess/__init__.py | 7 +- ppocr/postprocess/rec_postprocess.py | 62 ++++- tools/eval.py | 1 - tools/export_model.py | 6 + tools/program.py | 2 +- 17 files changed, 698 insertions(+), 11 deletions(-) create mode 100644 configs/rec/rec_efficientb3_fpn_pren.yml create mode 100644 ppocr/losses/rec_pren_loss.py create mode 100644 ppocr/modeling/backbones/rec_efficientb3_pren.py create mode 100644 ppocr/modeling/heads/rec_pren_head.py create mode 100644 ppocr/modeling/necks/pren_fpn.py diff --git a/configs/rec/rec_efficientb3_fpn_pren.yml b/configs/rec/rec_efficientb3_fpn_pren.yml new file mode 100644 index 00000000..0fac6a7a --- /dev/null +++ b/configs/rec/rec_efficientb3_fpn_pren.yml @@ -0,0 +1,92 @@ +Global: + use_gpu: True + epoch_num: 8 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/rec/pren_new + save_epoch_step: 3 + # evaluation is run every 2000 iterations after the 4000th iteration + eval_batch_step: [4000, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + max_text_length: &max_text_length 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_pren.txt + +Optimizer: + name: Adadelta + lr: + name: Piecewise + decay_epochs: [2, 5, 7] + values: [0.5, 0.1, 0.01, 0.001] + +Architecture: + model_type: rec + algorithm: PREN + in_channels: 3 + Backbone: + name: EfficientNetb3_PREN + Neck: + name: PRENFPN + n_r: 5 + d_model: 384 + max_len: *max_text_length + dropout: 0.1 + Head: + name: PRENHead + +Loss: + name: PRENLoss + +PostProcess: + name: PRENLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: + img_mode: BGR + channel_first: False + - PRENLabelEncode: + - RecAug: + - PRENResizeImg: + image_shape: [64, 256] # h,w + - KeepKeys: + keep_keys: ['image', 'label'] + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/validation/ + transforms: + - DecodeImage: + img_mode: BGR + channel_first: False + - PRENLabelEncode: + - PRENResizeImg: + image_shape: [64, 256] # h,w + - KeepKeys: + keep_keys: ['image', 'label'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 64 + num_workers: 8 diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 90a70875..b82725ff 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -22,7 +22,8 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, \ + SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .ColorJitter import ColorJitter diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index ef962b17..6f86be7d 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -785,6 +785,53 @@ class SARLabelEncode(BaseRecLabelEncode): return [self.padding_idx] +class PRENLabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path, + use_space_char=False, + **kwargs): + super(PRENLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + padding_str = '' # 0 + end_str = '' # 1 + unknown_str = '' # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def encode(self, text): + if len(text) == 0 or len(text) >= self.max_text_len: + return None + if self.lower: + text = text.lower() + text_list = [] + for char in text: + if char not in self.dict: + text_list.append(self.unknown_idx) + else: + text_list.append(self.dict[char]) + text_list.append(self.end_idx) + if len(text_list) < self.max_text_len: + text_list += [self.padding_idx] * ( + self.max_text_len - len(text_list)) + return text_list + + def __call__(self, data): + text = data['label'] + encoded_text = self.encode(text) + if encoded_text is None: + return None + data['label'] = np.array(encoded_text) + return data + + class VQATokenLabelEncode(object): """ Label encode for NLP VQA methods diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index b4de6de9..6f59fef6 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -141,6 +141,25 @@ class SARRecResizeImg(object): return data +class PRENResizeImg(object): + def __init__(self, image_shape, **kwargs): + """ + Accroding to original paper's realization, it's a hard resize method here. + So maybe you should optimize it to fit for your task better. + """ + self.dst_h, self.dst_w = image_shape + + def __call__(self, data): + img = data['image'] + resized_img = cv2.resize( + img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR) + resized_img = resized_img.transpose((2, 0, 1)) / 255 + resized_img -= 0.5 + resized_img /= 0.5 + data['image'] = resized_img.astype(np.float32) + return data + + def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 56e6d25d..d7f2b1c1 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -32,6 +32,7 @@ from .rec_srn_loss import SRNLoss from .rec_nrtr_loss import NRTRLoss from .rec_sar_loss import SARLoss from .rec_aster_loss import AsterLoss +from .rec_pren_loss import PRENLoss # cls loss from .cls_loss import ClsLoss @@ -58,7 +59,7 @@ def build_loss(config): 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', - 'VQASerTokenLayoutLMLoss', 'LossFromOutput' + 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_pren_loss.py b/ppocr/losses/rec_pren_loss.py new file mode 100644 index 00000000..7bc53d29 --- /dev/null +++ b/ppocr/losses/rec_pren_loss.py @@ -0,0 +1,30 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn + + +class PRENLoss(nn.Layer): + def __init__(self, **kwargs): + super(PRENLoss, self).__init__() + # note: 0 is padding idx + self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) + + def forward(self, predicts, batch): + loss = self.loss_func(predicts, batch[1].astype('int64')) + return {'loss': loss} diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index b34b7550..c89c7c25 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -30,9 +30,10 @@ def build_backbone(config, model_type): from .rec_resnet_31 import ResNet31 from .rec_resnet_aster import ResNet_ASTER from .rec_micronet import MicroNet + from .rec_efficientb3_pren import EfficientNetb3_PREN support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', - "ResNet31", "ResNet_ASTER", 'MicroNet' + "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN' ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_efficientb3_pren.py b/ppocr/modeling/backbones/rec_efficientb3_pren.py new file mode 100644 index 00000000..57eef178 --- /dev/null +++ b/ppocr/modeling/backbones/rec_efficientb3_pren.py @@ -0,0 +1,228 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Code is refer from: +https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from collections import namedtuple +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ['EfficientNetb3'] + + +class EffB3Params: + @staticmethod + def get_global_params(): + """ + The fllowing are efficientnetb3's arch superparams, but to fit for scene + text recognition task, the resolution(image_size) here is changed + from 300 to 64. + """ + GlobalParams = namedtuple('GlobalParams', [ + 'drop_connect_rate', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'image_size' + ]) + global_params = GlobalParams( + drop_connect_rate=0.3, + width_coefficient=1.2, + depth_coefficient=1.4, + depth_divisor=8, + image_size=64) + return global_params + + @staticmethod + def get_block_params(): + BlockParams = namedtuple('BlockParams', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'se_ratio', 'stride' + ]) + block_params = [ + BlockParams(3, 1, 32, 16, 1, True, 0.25, 1), + BlockParams(3, 2, 16, 24, 6, True, 0.25, 2), + BlockParams(5, 2, 24, 40, 6, True, 0.25, 2), + BlockParams(3, 3, 40, 80, 6, True, 0.25, 2), + BlockParams(5, 3, 80, 112, 6, True, 0.25, 1), + BlockParams(5, 4, 112, 192, 6, True, 0.25, 2), + BlockParams(3, 1, 192, 320, 6, True, 0.25, 1) + ] + return block_params + + +class EffUtils: + @staticmethod + def round_filters(filters, global_params): + """Calculate and round number of filters based on depth multiplier.""" + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + filters *= multiplier + new_filters = int(filters + divisor / 2) // divisor * divisor + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + @staticmethod + def round_repeats(repeats, global_params): + """Round number of filters based on depth multiplier.""" + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +class ConvBlock(nn.Layer): + def __init__(self, block_params): + super(ConvBlock, self).__init__() + self.block_args = block_params + self.has_se = (self.block_args.se_ratio is not None) and \ + (0 < self.block_args.se_ratio <= 1) + self.id_skip = block_params.id_skip + + # expansion phase + self.input_filters = self.block_args.input_filters + output_filters = \ + self.block_args.input_filters * self.block_args.expand_ratio + if self.block_args.expand_ratio != 1: + self.expand_conv = nn.Conv2D( + self.input_filters, output_filters, 1, bias_attr=False) + self.bn0 = nn.BatchNorm(output_filters) + + # depthwise conv phase + k = self.block_args.kernel_size + s = self.block_args.stride + self.depthwise_conv = nn.Conv2D( + output_filters, + output_filters, + groups=output_filters, + kernel_size=k, + stride=s, + padding='same', + bias_attr=False) + self.bn1 = nn.BatchNorm(output_filters) + + # squeeze and excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, + int(self.block_args.input_filters * + self.block_args.se_ratio)) + self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1) + self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1) + + # output phase + self.final_oup = self.block_args.output_filters + self.project_conv = nn.Conv2D( + output_filters, self.final_oup, 1, bias_attr=False) + self.bn2 = nn.BatchNorm(self.final_oup) + self.swish = nn.Swish() + + def drop_connect(self, inputs, p, training): + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + random_tensor = keep_prob + random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype) + random_tensor = paddle.to_tensor(random_tensor, place=inputs.place) + binary_tensor = paddle.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + def forward(self, inputs, drop_connect_rate=None): + # expansion and depthwise conv + x = inputs + if self.block_args.expand_ratio != 1: + x = self.swish(self.bn0(self.expand_conv(inputs))) + x = self.swish(self.bn1(self.depthwise_conv(x))) + + # squeeze and excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed))) + x = F.sigmoid(x_squeezed) * x + x = self.bn2(self.project_conv(x)) + + # skip conntection and drop connect + if self.id_skip and self.block_args.stride == 1 and \ + self.input_filters == self.final_oup: + if drop_connect_rate: + x = self.drop_connect( + x, p=drop_connect_rate, training=self.training) + x = x + inputs + return x + + +class EfficientNetb3_PREN(nn.Layer): + def __init__(self, in_channels): + super(EfficientNetb3_PREN, self).__init__() + self.blocks_params = EffB3Params.get_block_params() + self.global_params = EffB3Params.get_global_params() + self.out_channels = [] + # stem + stem_channels = EffUtils.round_filters(32, self.global_params) + self.conv_stem = nn.Conv2D( + in_channels, stem_channels, 3, 2, padding='same', bias_attr=False) + self.bn0 = nn.BatchNorm(stem_channels) + + self.blocks = [] + # to extract three feature maps for fpn based on efficientnetb3 backbone + self.concerned_block_idxes = [7, 17, 25] + concerned_idx = 0 + for i, block_params in enumerate(self.blocks_params): + block_params = block_params._replace( + input_filters=EffUtils.round_filters(block_params.input_filters, + self.global_params), + output_filters=EffUtils.round_filters( + block_params.output_filters, self.global_params), + num_repeat=EffUtils.round_repeats(block_params.num_repeat, + self.global_params)) + self.blocks.append( + self.add_sublayer("{}-0".format(i), ConvBlock(block_params))) + concerned_idx += 1 + if concerned_idx in self.concerned_block_idxes: + self.out_channels.append(block_params.output_filters) + if block_params.num_repeat > 1: + block_params = block_params._replace( + input_filters=block_params.output_filters, stride=1) + for j in range(block_params.num_repeat - 1): + self.blocks.append( + self.add_sublayer('{}-{}'.format(i, j + 1), + ConvBlock(block_params))) + concerned_idx += 1 + if concerned_idx in self.concerned_block_idxes: + self.out_channels.append(block_params.output_filters) + + self.swish = nn.Swish() + + def forward(self, inputs): + outs = [] + + x = self.swish(self.bn0(self.conv_stem(inputs))) + for idx, block in enumerate(self.blocks): + drop_connect_rate = self.global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self.blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + if idx in self.concerned_block_idxes: + outs.append(x) + return outs diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 4a27ce52..a62c8bf8 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -30,6 +30,7 @@ def build_head(config): from .rec_nrtr_head import Transformer from .rec_sar_head import SARHead from .rec_aster_head import AsterHead + from .rec_pren_head import PRENHead # cls head from .cls_head import ClsHead @@ -42,7 +43,7 @@ def build_head(config): support_dict = [ 'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', - 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead' + 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead' ] #table head diff --git a/ppocr/modeling/heads/rec_pren_head.py b/ppocr/modeling/heads/rec_pren_head.py new file mode 100644 index 00000000..c9e4b3e9 --- /dev/null +++ b/ppocr/modeling/heads/rec_pren_head.py @@ -0,0 +1,34 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +from paddle.nn import functional as F + + +class PRENHead(nn.Layer): + def __init__(self, in_channels, out_channels, **kwargs): + super(PRENHead, self).__init__() + self.linear = nn.Linear(in_channels, out_channels) + + def forward(self, x, targets=None): + predicts = self.linear(x) + + if not self.training: + predicts = F.softmax(predicts, axis=2) + + return predicts diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index 5606a4c3..fdece49d 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -23,7 +23,11 @@ def build_neck(config): from .pg_fpn import PGFPN from .table_fpn import TableFPN from .fpn import FPN - support_dict = ['FPN','DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN'] + from .pren_fpn import PRENFPN + support_dict = [ + 'FPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', + 'TableFPN', 'PRENFPN' + ] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/pren_fpn.py b/ppocr/modeling/necks/pren_fpn.py new file mode 100644 index 00000000..afbdcea8 --- /dev/null +++ b/ppocr/modeling/necks/pren_fpn.py @@ -0,0 +1,163 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Code is refer from: +https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F + + +class PoolAggregate(nn.Layer): + def __init__(self, n_r, d_in, d_middle=None, d_out=None): + super(PoolAggregate, self).__init__() + if not d_middle: + d_middle = d_in + if not d_out: + d_out = d_in + + self.d_in = d_in + self.d_middle = d_middle + self.d_out = d_out + self.act = nn.Swish() + + self.n_r = n_r + self.aggs = self._build_aggs() + + def _build_aggs(self): + aggs = [] + for i in range(self.n_r): + aggs.append( + self.add_sublayer( + '{}'.format(i), + nn.Sequential( + ('conv1', nn.Conv2D( + self.d_in, self.d_middle, 3, 2, 1, bias_attr=False) + ), ('bn1', nn.BatchNorm(self.d_middle)), + ('act', self.act), ('conv2', nn.Conv2D( + self.d_middle, self.d_out, 3, 2, 1, bias_attr=False + )), ('bn2', nn.BatchNorm(self.d_out))))) + return aggs + + def forward(self, x): + b = x.shape[0] + outs = [] + for agg in self.aggs: + y = agg(x) + p = F.adaptive_avg_pool2d(y, 1) + outs.append(p.reshape((b, 1, self.d_out))) + out = paddle.concat(outs, 1) + return out + + +class WeightAggregate(nn.Layer): + def __init__(self, n_r, d_in, d_middle=None, d_out=None): + super(WeightAggregate, self).__init__() + if not d_middle: + d_middle = d_in + if not d_out: + d_out = d_in + + self.n_r = n_r + self.d_out = d_out + self.act = nn.Swish() + + self.conv_n = nn.Sequential( + ('conv1', nn.Conv2D( + d_in, d_in, 3, 1, 1, + bias_attr=False)), ('bn1', nn.BatchNorm(d_in)), + ('act1', self.act), ('conv2', nn.Conv2D( + d_in, n_r, 1, bias_attr=False)), ('bn2', nn.BatchNorm(n_r)), + ('act2', nn.Sigmoid())) + self.conv_d = nn.Sequential( + ('conv1', nn.Conv2D( + d_in, d_middle, 3, 1, 1, + bias_attr=False)), ('bn1', nn.BatchNorm(d_middle)), + ('act1', self.act), ('conv2', nn.Conv2D( + d_middle, d_out, 1, + bias_attr=False)), ('bn2', nn.BatchNorm(d_out))) + + def forward(self, x): + b, _, h, w = x.shape + + hmaps = self.conv_n(x) + fmaps = self.conv_d(x) + r = paddle.bmm( + hmaps.reshape((b, self.n_r, h * w)), + fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1))) + return r + + +class GCN(nn.Layer): + def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1): + super(GCN, self).__init__() + if not d_out: + d_out = d_in + if not n_out: + n_out = d_in + + self.conv_n = nn.Conv1D(n_in, n_out, 1) + self.linear = nn.Linear(d_in, d_out) + self.dropout = nn.Dropout(dropout) + self.act = nn.Swish() + + def forward(self, x): + x = self.conv_n(x) + x = self.dropout(self.linear(x)) + return self.act(x) + + +class PRENFPN(nn.Layer): + def __init__(self, in_channels, n_r, d_model, max_len, dropout): + super(PRENFPN, self).__init__() + assert len(in_channels) == 3, "in_channels' length must be 3." + c1, c2, c3 = in_channels # the depths are from big to small + # build fpn + assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model) + self.agg_p1 = PoolAggregate(n_r, c1, d_out=d_model // 3) + self.agg_p2 = PoolAggregate(n_r, c2, d_out=d_model // 3) + self.agg_p3 = PoolAggregate(n_r, c3, d_out=d_model // 3) + + self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, d_model // 3) + self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, d_model // 3) + self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, d_model // 3) + + self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout) + self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout) + + self.out_channels = d_model + + def forward(self, inputs): + f3, f5, f7 = inputs + + rp1 = self.agg_p1(f3) + rp2 = self.agg_p2(f5) + rp3 = self.agg_p3(f7) + rp = paddle.concat([rp1, rp2, rp3], 2) # [b,nr,d] + + rw1 = self.agg_w1(f3) + rw2 = self.agg_w2(f5) + rw3 = self.agg_w3(f7) + rw = paddle.concat([rw1, rw2, rw3], 2) # [b,nr,d] + + y1 = self.gcn_pool(rp) + y2 = self.gcn_weight(rw) + y = 0.5 * (y1 + y2) + return y # [b,max_len,d] diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 811bf57b..8caea69c 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -24,8 +24,9 @@ __all__ = ['build_post_process'] from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess -from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ - TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ + DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ + SEEDLabelDecode, PRENLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -39,7 +40,7 @@ def build_post_process(config, global_config=None): 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', - 'VQAReTokenLayoutLMPostProcess' + 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index caaa2948..93d38554 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np -import string import paddle from paddle.nn import functional as F import re @@ -652,3 +652,63 @@ class SARLabelDecode(BaseRecLabelDecode): def get_ignored_tokens(self): return [self.padding_idx] + + +class PRENLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(PRENLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + padding_str = '' # 0 + end_str = '' # 1 + unknown_str = '' # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def decode(self, text_index, text_prob=None): + """ convert text-index into text-label. """ + result_list = [] + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == self.end_idx: + break + if text_index[batch_idx][idx] in \ + [self.padding_idx, self.unknown_idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + if len(text) > 0: + result_list.append((text, np.mean(conf_list))) + else: + # here confidence of empty recog result is 1 + result_list.append(('', 1)) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob) + if label is None: + return text + label = self.decode(label) + return text, label diff --git a/tools/eval.py b/tools/eval.py index 3a25c266..f6fcf14c 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -28,7 +28,6 @@ from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model -from ppocr.utils.utility import print_dict import tools.program as program diff --git a/tools/export_model.py b/tools/export_model.py index 695af5c8..bd647fc7 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -55,6 +55,12 @@ def export_single_model(model, arch_config, save_path, logger): shape=[None, 3, 48, 160], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "PREN": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 64, 512], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": diff --git a/tools/program.py b/tools/program.py index e92bef33..d04a0e2e 100755 --- a/tools/program.py +++ b/tools/program.py @@ -541,7 +541,7 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN' ] device = 'cpu' -- GitLab