diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml similarity index 99% rename from configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml rename to configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index a7221d7c4c39f6f37328e290b80577af407b6949..278e3265728a73655ac602c030d0ce53e8beb7d8 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -141,6 +141,7 @@ Train: img_mode: BGR channel_first: False - DetLabelEncode: # Class handling label + - CopyPaste: - IaaAugment: augmenter_args: - { 'type': Fliplr, 'args': { 'p': 0.5 } } diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml similarity index 99% rename from configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml rename to configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml index 9566a3f7b2dbda633d7333dc890565006d2c570c..4ba9487931eff605a3014c4b95aa23f432fff285 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml @@ -91,7 +91,7 @@ Optimizer: PostProcess: name: DistillationDBPostProcess - model_name: ["Student", "Student2"] + model_name: ["Student"] key: head_out thresh: 0.3 box_thresh: 0.6 diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml similarity index 100% rename from configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml rename to configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_student.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml similarity index 100% rename from configs/det/ch_PP-OCRv2/ch_PP-OCR_det_student.yml rename to configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml diff --git a/configs/det/det_r50_vd_sast_icdar15.yml b/configs/det/det_r50_vd_sast_icdar15.yml index dbfcefca964e73d42298fbbbc1e654b3bd809c77..e1bf6fad0c4f526c2b635494560446e779bdb572 100755 --- a/configs/det/det_r50_vd_sast_icdar15.yml +++ b/configs/det/det_r50_vd_sast_icdar15.yml @@ -8,7 +8,7 @@ Global: # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: [4000, 5000] cal_metric_during_train: False - pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained checkpoints: save_inference_dir: use_visualdl: False @@ -106,4 +106,4 @@ Eval: shuffle: False drop_last: False batch_size_per_card: 1 # must be 1 - num_workers: 2 \ No newline at end of file + num_workers: 2 diff --git a/configs/det/det_r50_vd_sast_totaltext.yml b/configs/det/det_r50_vd_sast_totaltext.yml index 88dd31f3c21b184d956ad718dae808bb6054532e..557ff8bf0d120d704ab83da4d7e26d3a8562696e 100755 --- a/configs/det/det_r50_vd_sast_totaltext.yml +++ b/configs/det/det_r50_vd_sast_totaltext.yml @@ -8,7 +8,7 @@ Global: # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: [4000, 5000] cal_metric_during_train: False - pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained checkpoints: save_inference_dir: use_visualdl: False @@ -105,4 +105,4 @@ Eval: shuffle: False drop_last: False batch_size_per_card: 1 # must be 1 - num_workers: 2 \ No newline at end of file + num_workers: 2 diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index a74e18d318699685400cc48430c04db3fef70b60..1a91ea95afb4ff91d3fd68fe0df6afaac9304661 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -1,29 +1,28 @@ Global: use_gpu: true - epoch_num: 50 + epoch_num: 400 log_smooth_window: 20 print_batch_step: 5 save_model_dir: ./output/table_mv3/ - save_epoch_step: 5 + save_epoch_step: 3 # evaluation is run every 400 iterations after the 0th iteration eval_batch_step: [0, 400] cal_metric_during_train: True - pretrained_model: + pretrained_model: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/table/table.jpg # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en max_text_length: 100 - max_elem_length: 500 + max_elem_length: 800 max_cell_num: 500 infer_mode: False process_total_num: 0 process_cut_num: 0 - Optimizer: name: Adam beta1: 0.9 @@ -41,13 +40,15 @@ Architecture: Backbone: name: MobileNetV3 scale: 1.0 - model_name: small - disable_se: True + model_name: large Head: name: TableAttentionHead hidden_size: 256 l2_decay: 0.00001 loc_type: 2 + max_text_length: 100 + max_elem_length: 800 + max_cell_num: 500 Loss: name: TableAttentionLoss diff --git a/deploy/hubserving/ocr_det/module.py b/deploy/hubserving/ocr_det/module.py index c7d253f5ec8d626279c9eb493e15d1c4c83cfbfd..19f528d187c6d2e379811909801ce4dec83105fc 100644 --- a/deploy/hubserving/ocr_det/module.py +++ b/deploy/hubserving/ocr_det/module.py @@ -18,7 +18,7 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_det import TextDetector from tools.infer.utility import parse_args -from deploy.hubserving.ocr_system.params import read_params +from deploy.hubserving.ocr_det.params import read_params @moduleinfo( diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 6f19c6aa0e204d5af2e785f91f67257576f3d6db..b4ec6bc2c32bc36f1e84cea583b4f0e7eb42e49d 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -50,7 +50,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] -- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))[13] 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -78,4 +78,3 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训 ## 3. 模型推理 上述模型中除PP-OCR系列模型以外,其余模型仅支持基于Python引擎的推理,具体内容可参考[基于Python预测引擎推理](./inference.md) - diff --git a/doc/doc_ch/reference.md b/doc/doc_ch/reference.md index f1337dedc96c685173cbcc8450a57c259d2c0029..3347447741a7852b30e793bd0d30696c190598a0 100644 --- a/doc/doc_ch/reference.md +++ b/doc/doc_ch/reference.md @@ -112,4 +112,14 @@ year={2016} } +13.NRTR +@misc{sheng2019nrtr, + title={NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition}, + author={Fenfen Sheng and Zhineng Chen and Bo Xu}, + year={2019}, + eprint={1806.00926}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + ``` diff --git a/ppocr/data/imaug/east_process.py b/ppocr/data/imaug/east_process.py index b1d7a5e51939af981dd62c269c930f4bf9ba4179..598b88daee938a159f3e73d7296ede9ec6f4bfcd 100644 --- a/ppocr/data/imaug/east_process.py +++ b/ppocr/data/imaug/east_process.py @@ -11,7 +11,10 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. - +""" +This code is refered from: +https://github.com/songdejia/EAST/blob/master/data_utils.py +""" import math import cv2 import numpy as np @@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain'] class EASTProcessTrain(object): def __init__(self, - image_shape = [512, 512], - background_ratio = 0.125, - min_crop_side_ratio = 0.1, - min_text_size = 10, + image_shape=[512, 512], + background_ratio=0.125, + min_crop_side_ratio=0.1, + min_text_size=10, **kwargs): self.input_size = image_shape[1] self.random_scale = np.array([0.5, 1, 2.0, 3.0]) @@ -282,12 +285,7 @@ class EASTProcessTrain(object): 1.0 / max(min(poly_h, poly_w), 1.0) return score_map, geo_map, training_mask - def crop_area(self, - im, - polys, - tags, - crop_background=False, - max_tries=50): + def crop_area(self, im, polys, tags, crop_background=False, max_tries=50): """ make random crop from the input image :param im: @@ -436,4 +434,4 @@ class EASTProcessTrain(object): data['geo_map'] = geo_map data['training_mask'] = training_mask # print(im.shape, score_map.shape, geo_map.shape, training_mask.shape) - return data \ No newline at end of file + return data diff --git a/ppocr/data/imaug/iaa_augment.py b/ppocr/data/imaug/iaa_augment.py index 9ce6bd4209034389df04334a83717142ca8c7b40..f7553f0959f5c8ef2c46325ef73c2ffcc8c2b35e 100644 --- a/ppocr/data/imaug/iaa_augment.py +++ b/ppocr/data/imaug/iaa_augment.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/data/imaug/make_border_map.py b/ppocr/data/imaug/make_border_map.py index cc2c9034e147eb7bb6a70e43eda4903337a523f0..2eb76f6ab305f9132088600e116f4d0ec9d5348c 100644 --- a/ppocr/data/imaug/make_border_map.py +++ b/ppocr/data/imaug/make_border_map.py @@ -1,5 +1,20 @@ -# -*- coding:utf-8 -*- - +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/data/imaug/make_shrink_map.py b/ppocr/data/imaug/make_shrink_map.py index 15e8afa05bb9f7315a2e9342c78cb98718a54df9..8834b251b030471dcf3293c4e4d869a7625bfcde 100644 --- a/ppocr/data/imaug/make_shrink_map.py +++ b/ppocr/data/imaug/make_shrink_map.py @@ -1,5 +1,20 @@ -# -*- coding:utf-8 -*- - +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/data/imaug/random_crop_data.py b/ppocr/data/imaug/random_crop_data.py index 4d67cff61d6f340be6d80d8243c68909a94c4e88..6d02a584ed0b8447769cc56416c283804d06b02f 100644 --- a/ppocr/data/imaug/random_crop_data.py +++ b/ppocr/data/imaug/random_crop_data.py @@ -1,5 +1,20 @@ -# -*- coding:utf-8 -*- - +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/data/imaug/sast_process.py b/ppocr/data/imaug/sast_process.py index 1536dceb8ee5999226cfe7cf455d70e39b449530..08d03b194dcfab92ab59329857d4a1326531218e 100644 --- a/ppocr/data/imaug/sast_process.py +++ b/ppocr/data/imaug/sast_process.py @@ -11,7 +11,10 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. - +""" +This part code is refered from: +https://github.com/songdejia/EAST/blob/master/data_utils.py +""" import math import cv2 import numpy as np diff --git a/ppocr/data/imaug/text_image_aug/augment.py b/ppocr/data/imaug/text_image_aug/augment.py index 1aeff3733a4521c56dd5972fc058f6e0c245e4b7..90d7e67f5a6c1dbcb3f58101992bbb8d68f4d48f 100644 --- a/ppocr/data/imaug/text_image_aug/augment.py +++ b/ppocr/data/imaug/text_image_aug/augment.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py +""" import numpy as np from .warp_mls import WarpMLS diff --git a/ppocr/data/imaug/text_image_aug/warp_mls.py b/ppocr/data/imaug/text_image_aug/warp_mls.py index d6cbe749b61aa4cf3163927c096868c83f4a4cdd..94d551af17f3dbe3d8ff891adf92f024dcb17f5b 100644 --- a/ppocr/data/imaug/text_image_aug/warp_mls.py +++ b/ppocr/data/imaug/text_image_aug/warp_mls.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py +""" import numpy as np diff --git a/ppocr/losses/det_basic_loss.py b/ppocr/losses/det_basic_loss.py index eba5526dd2bd1c0328130b50817172df437cc360..d11d5ef7ddcc870f2811a29391a632dc62574a6b 100644 --- a/ppocr/losses/det_basic_loss.py +++ b/ppocr/losses/det_basic_loss.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/losses/det_db_loss.py b/ppocr/losses/det_db_loss.py index b079aabff7c7deccc7e365b91c9407f7e894bcb9..708ffbdb47f349304e2bfd781a836e79348475f4 100755 --- a/ppocr/losses/det_db_loss.py +++ b/ppocr/losses/det_db_loss.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py +""" from __future__ import absolute_import from __future__ import division diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py index fe874fac1af439bfb47ba9050a61f02db302e224..858f8b915340fe24e8e510f15caf093bc36309c0 100644 --- a/ppocr/modeling/backbones/rec_mv1_enhance.py +++ b/ppocr/modeling/backbones/rec_mv1_enhance.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index 4286d7691d1abcf80c283d1c1ab76f8cd1f4a634..6d77e42eb5def579052687ab6fdc265159311884 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -75,7 +75,7 @@ class AttentionHead(nn.Layer): probs_step, axis=1)], axis=1) next_input = probs_step.argmax(axis=1) targets = next_input - + probs = paddle.nn.functional.softmax(probs, axis=2) return probs diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index 155f036d15673135eae9e5ee493648603609535d..e354f40d6518c1f7ca22e93694b1c6668fc003d2 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -23,32 +23,40 @@ import numpy as np class TableAttentionHead(nn.Layer): - def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs): + def __init__(self, + in_channels, + hidden_size, + loc_type, + in_max_len=488, + max_text_length=100, + max_elem_length=800, + max_cell_num=500, + **kwargs): super(TableAttentionHead, self).__init__() self.input_size = in_channels[-1] self.hidden_size = hidden_size self.elem_num = 30 - self.max_text_length = 100 - self.max_elem_length = 500 - self.max_cell_num = 500 + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.elem_num, use_gru=False) self.structure_generator = nn.Linear(hidden_size, self.elem_num) self.loc_type = loc_type self.in_max_len = in_max_len - + if self.loc_type == 1: self.loc_generator = nn.Linear(hidden_size, 4) else: if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1) + self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1) + self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) else: - self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1) + self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) - + def _char_to_onehot(self, input_char, onehot_dim): input_ont_hot = F.one_hot(input_char, onehot_dim) return input_ont_hot @@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer): if len(fea.shape) == 3: pass else: - last_shape = int(np.prod(fea.shape[2:])) # gry added + last_shape = int(np.prod(fea.shape[2:])) # gry added fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) batch_size = fea.shape[0] - + hidden = paddle.zeros((batch_size, self.hidden_size)) output_hiddens = [] if self.training and targets is not None: structure = targets[0] - for i in range(self.max_elem_length+1): + for i in range(self.max_elem_length + 1): elem_onehots = self._char_to_onehot( structure[:, i], onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer): alpha = None max_elem_length = paddle.to_tensor(self.max_elem_length) i = 0 - while i < max_elem_length+1: + while i < max_elem_length + 1: elem_onehots = self._char_to_onehot( temp_elem, onehot_dim=self.elem_num) (outputs, hidden), alpha = self.structure_attention_cell( @@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer): structure_probs_step = self.structure_generator(outputs) temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") i += 1 - + output = paddle.concat(output_hiddens, axis=1) structure_probs = self.structure_generator(output) structure_probs = F.softmax(structure_probs) @@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer): loc_concat = paddle.concat([output, loc_fea], axis=2) loc_preds = self.loc_generator(loc_concat) loc_preds = F.sigmoid(loc_preds) - return {'structure_probs':structure_probs, 'loc_preds':loc_preds} + return {'structure_probs': structure_probs, 'loc_preds': loc_preds} + - class AttentionGRUCell(nn.Layer): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): super(AttentionGRUCell, self).__init__() diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index dcce6246ac64b4b84229cbd69a4dc53c658b4c7b..9bdab0f85112b90d8da959dce4e258188a812052 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py +""" from __future__ import absolute_import from __future__ import division @@ -231,7 +235,8 @@ class GridGenerator(nn.Layer): """ Return inv_delta_C which is needed to calculate T """ F = self.F hat_eye = paddle.eye(F, dtype='float64') # F x F - hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye + hat_C = paddle.norm( + C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye hat_C = (hat_C**2) * paddle.log(hat_C) delta_C = paddle.concat( # F+3 x F+3 [ diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index d9c9869dfcd35cb9b491db826f3bff5f766723f4..27b428ef2e73c9abf81d3881b23979343c8595b2 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refered from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -190,7 +193,8 @@ class DBPostProcess(object): class DistillationDBPostProcess(object): - def __init__(self, model_name=["student"], + def __init__(self, + model_name=["student"], key=None, thresh=0.3, box_thresh=0.6, @@ -201,12 +205,13 @@ class DistillationDBPostProcess(object): **kwargs): self.model_name = model_name self.key = key - self.post_process = DBPostProcess(thresh=thresh, - box_thresh=box_thresh, - max_candidates=max_candidates, - unclip_ratio=unclip_ratio, - use_dilation=use_dilation, - score_mode=score_mode) + self.post_process = DBPostProcess( + thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode) def __call__(self, predicts, shape_list): results = {} diff --git a/ppocr/postprocess/locality_aware_nms.py b/ppocr/postprocess/locality_aware_nms.py index 53280cc13ed7e41859e23e2517938d4f6eb07076..d305ef681882b4a393a73190bcbd20a65d1f0c15 100644 --- a/ppocr/postprocess/locality_aware_nms.py +++ b/ppocr/postprocess/locality_aware_nms.py @@ -1,5 +1,6 @@ """ Locality aware nms. +This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py """ import numpy as np diff --git a/ppocr/utils/logging.py b/ppocr/utils/logging.py index 11896c37d9285e19a9526caa9c637d7eda7b1979..b55d5a773e96a4b0ac53b96048be22b461b2cf9b 100644 --- a/ppocr/utils/logging.py +++ b/ppocr/utils/logging.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/WenmuZhou/PytorchOCR/blob/master/torchocr/utils/logging.py +""" import os import sys import logging diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 103dd9bd8dcf0c8c3fcca1acfc5d01a3134e28a3..7a4e8392c4dad4c63bc161c7c34e8976db0f66ff 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -187,7 +187,7 @@ def create_predictor(args, mode, logger): "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2] } max_input_shape = { - "x": [1, 3, 2000, 2000], + "x": [1, 3, 1280, 1280], "conv2d_92.tmp_0": [1, 120, 400, 400], "conv2d_91.tmp_0": [1, 24, 200, 200], "conv2d_59.tmp_0": [1, 96, 400, 400], @@ -237,16 +237,16 @@ def create_predictor(args, mode, logger): opt_input_shape.update(opt_pact_shape) elif mode == "rec": min_input_shape = {"x": [1, 3, 32, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} + max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1024]} opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} elif mode == "cls": min_input_shape = {"x": [1, 3, 48, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]} + max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]} opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} else: min_input_shape = {"x": [1, 3, 10, 10]} - max_input_shape = {"x": [1, 3, 1000, 1000]} - opt_input_shape = {"x": [1, 3, 500, 500]} + max_input_shape = {"x": [1, 3, 512, 512]} + opt_input_shape = {"x": [1, 3, 256, 256]} config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape)