diff --git a/configs/rec/rec_mv3_tps_bilstm_att.yml b/configs/rec/rec_mv3_tps_bilstm_att.yml new file mode 100644 index 0000000000000000000000000000000000000000..c64b2ccc266b67263e291d03d711c58cd9403273 --- /dev/null +++ b/configs/rec/rec_mv3_tps_bilstm_att.yml @@ -0,0 +1,102 @@ +Global: + use_gpu: true + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_mv3_tps_bilstm_att/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must be set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0.00001 + +Architecture: + model_type: rec + algorithm: RARE + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 0.1 + model_name: small + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 96 + Head: + name: AttentionHead + hidden_size: 96 + + +Loss: + name: AttentionLoss + +PostProcess: + name: AttnLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ../training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ../validation/ + 
transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 1 diff --git a/configs/rec/rec_r34_vd_tps_bilstm_att.yml b/configs/rec/rec_r34_vd_tps_bilstm_att.yml new file mode 100644 index 0000000000000000000000000000000000000000..7be34b9c55b628859a98200377336d5ca375aedb --- /dev/null +++ b/configs/rec/rec_r34_vd_tps_bilstm_att.yml @@ -0,0 +1,101 @@ +Global: + use_gpu: true + epoch_num: 400 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/b3_rare_r34_none_gru/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must be set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0.00000 + +Architecture: + model_type: rec + algorithm: RARE + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 0.1 + model_name: large + Backbone: + name: ResNet + layers: 34 + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 256 #96 + Head: + name: AttentionHead # AttentionHead + hidden_size: 256 # + l2_decay: 0.00001 + +Loss: + name: AttentionLoss + +PostProcess: + name: AttnLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ../training/ + transforms: + - DecodeImage: # load image + 
img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ../validation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - AttnLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 8 diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index abbc5da4c21cf89466a5faef6cf6cb0c1eb14d13..4ff7482c932d1a32567b4b1e4b94135d339075e7 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -40,7 +40,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐) - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] -- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon +- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -53,6 +53,9 @@ PaddleOCR基于动态图开源的文本识别算法列表: |CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| 
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| +|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| +|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | + PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index bc877ab78c583f04dd0bf740712457094325d00e..f36e801988d9d34beb09c4ee9caebccf5ce9c776 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -201,6 +201,8 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | +| rec_mv3_tps_bilstm_att.yml | RARE | Mobilenet_v3 | TPS | BiLSTM | att | +| rec_r34_vd_tps_bilstm_att.yml | RARE | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 7d7896e7144c9d2db28b29a6f16b23677925d67a..423fe807b5e8329f0aed4e56f3141b453b083899 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -42,7 +42,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7] - [x] 
Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] -- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon +- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -55,6 +55,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| +|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| +|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index f29703d14454ce979ad4f7d8cda0d2768721b53d..c2ff20226855358260b0c1ec3e9838899c20a4e6 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -195,8 +195,11 @@ 
If the evaluation set is large, the test will be time-consuming. It is recommend | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | +| rec_mv3_tps_bilstm_att.yml | RARE | Mobilenet_v3 | TPS | BiLSTM | att | +| rec_r34_vd_tps_bilstm_att.yml | RARE | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | + For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: co diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 61c0c196b3a48911707dc5210a410145ec93a76d..26ac4d818634f83ebbc160d593b73a5684776170 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -199,16 +199,30 @@ class AttnLabelEncode(BaseRecLabelEncode): super(AttnLabelEncode, self).__init__(max_text_length, character_dict_path, character_type, use_space_char) - self.beg_str = "sos" - self.end_str = "eos" def add_special_char(self, dict_character): - dict_character = [self.beg_str, self.end_str] + dict_character + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] return dict_character - def __call__(self, text): + def __call__(self, data): + text = data['label'] text = self.encode(text) - return text + if text is None: + return None + if len(text) >= self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len + - len(text) - 1) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = 
self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] def get_beg_end_flag_idx(self, beg_or_end): if beg_or_end == "beg": diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index b280eb333e8910ac2378962eb4ffde5a98f31efd..3881abf7741b8be78306bd070afb11df15606327 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -23,13 +23,15 @@ def build_loss(config): # rec loss from .rec_ctc_loss import CTCLoss + from .rec_att_loss import AttentionLoss from .rec_srn_loss import SRNLoss # cls loss from .cls_loss import ClsLoss support_dict = [ - 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss' + 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', + 'SRNLoss' ] config = copy.deepcopy(config) diff --git a/ppocr/losses/rec_att_loss.py b/ppocr/losses/rec_att_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6e2f67483c86a45f3aa1feb1e1fac1a5013bfb46 --- /dev/null +++ b/ppocr/losses/rec_att_loss.py @@ -0,0 +1,39 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class AttentionLoss(nn.Layer): + def __init__(self, **kwargs): + super(AttentionLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') + + def forward(self, predicts, batch): + targets = batch[1].astype("int64") + label_lengths = batch[2].astype('int64') + batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[ + 1], predicts.shape[2] + assert len(targets.shape) == len(list(predicts.shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]]) + targets = paddle.reshape(targets, [-1]) + + return {'loss': paddle.sum(self.loss_func(inputs, targets))} diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 1a39ca412a1faf9a8cefb1de0db66c33ed9dc27e..efe05718506e94a5ae8ad5ff47bcff26d44c1473 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,12 +23,14 @@ def build_head(config): # rec head from .rec_ctc_head import CTCHead + from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead # cls head from .cls_head import ClsHead support_dict = [ - 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead' + 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', + 'SRNHead' ] module_name = config.pop('name') diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a7cfe1282141d4646bf3c410d4b0f9a3e94d28fb --- /dev/null +++ b/ppocr/modeling/heads/rec_att_head.py @@ -0,0 +1,199 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + + +class AttentionHead(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionHead, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionGRUCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = paddle.zeros((batch_size, self.hidden_size)) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + char_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + (outputs, hidden), alpha = self.attention_cell(hidden, inputs, + 
char_onehots) + probs_step = self.generator(outputs) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + next_input = probs_step.argmax(axis=1) + targets = next_input + + return probs + + +class AttentionGRUCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionGRUCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) + + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha + + +class AttentionLSTM(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionLSTM, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = 
(paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( + (batch_size, self.hidden_size))) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + # one-hot vectors for a i-th char + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + + hidden = (hidden[1][0], hidden[1][1]) + output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(hidden[0]) + hidden = (hidden[1][0], hidden[1][1]) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + + next_input = probs_step.argmax(axis=1) + + targets = next_input + + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = 
F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 76a700e1599b143288814778dcc948126a98151d..af243caa44e8390657b7a95e971aede0c0f90edd 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -135,16 +135,62 @@ class AttnLabelDecode(BaseRecLabelDecode): **kwargs): super(AttnLabelDecode, self).__init__(character_dict_path, character_type, use_space_char) - self.beg_str = "sos" - self.end_str = "eos" def add_special_char(self, dict_character): - dict_character = [self.beg_str, self.end_str] + dict_character + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + dict_character + [self.end_str] return dict_character - def __call__(self, text): + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + [beg_idx, end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ text = self.decode(text) - return text + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label def get_ignored_tokens(self): beg_idx = self.get_beg_end_flag_idx("beg") diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 8c4f9214db9621fe4e0393ed3dac0e9a7ccedbf6..de7ee9d342063161f2e329c99d2428051c0ecf8c 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -184,4 +184,4 @@ def main(args): if __name__ == "__main__": - main(utility.parse_args()) \ No newline at end of file + main(utility.parse_args())