diff --git a/README_cn.md b/README_cn.md
index cc5cb00a38d97fd3dba46b30a76f2dc606e8d027..ebfc4b1d95b020269e5da635e2cf8074efd987c4 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -122,7 +122,10 @@ Text recognition algorithms open-sourced by PaddleOCR:
 - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
 - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
 - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) (developed by Baidu, coming soon)
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294)) (developed by Baidu)
+
+*Note:* The SRN model augments the two training sets mentioned above with a data-perturbation method; the augmented data can be downloaded from [Baidu Netdisk](todo).
+The original paper reports an average accuracy of 89.74% with two-stage training; PaddleOCR uses one-stage training and reaches an average accuracy of 88.33%. Both pretrained weights are included in the [download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar).

 Following the text recognition training and evaluation protocol of [DTRB](https://arxiv.org/abs/1904.01906), the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. Results:

@@ -136,6 +139,7 @@ Text recognition algorithms open-sourced by PaddleOCR:
 |STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
 |RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
 |RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
+|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|

 Using the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt) street-view dataset, 300k images are cropped according to the ground truth and position-rectified; in addition, 5 million synthetic images generated from the LSVT corpus are used to train the Chinese model. The corresponding configuration and pretrained files are as follows:

diff --git a/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml b/configs/rec/rec_r50fpn_vd_none_srn.yml
similarity index 100%
rename from configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml
rename to configs/rec/rec_r50fpn_vd_none_srn.yml
diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py
index 5f2963ace5fa70111a4c7610178921199ae389a2..575658ef0421e3cedf44cedea67caac47de0fdd5 100755
--- a/ppocr/utils/character.py
+++ b/ppocr/utils/character.py
@@ -25,7 +25,7 @@ class CharacterOps(object):
     def __init__(self, config):
         self.character_type = config['character_type']
         self.loss_type = config['loss_type']
-        self.max_text_len = config['max_text_length']
+        self.max_text_len = 25
         if self.character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index b51c49fcb90657d766bddb3c39b4339e6a342a2f..c81b4eb2560ee5ad66a85c96efe4de935a2beee1 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -40,7 +40,8 @@ class TextRecognizer(object):
         char_ops_params = {
             "character_type": args.rec_char_type,
             "character_dict_path": args.rec_char_dict_path,
-            "use_space_char": args.use_space_char
+            "use_space_char": args.use_space_char,
+            "max_text_length": args.max_text_length
         }
         if self.rec_algorithm != "RARE":
             char_ops_params['loss_type'] = 'ctc'
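The `predict_rec.py` hunk forwards `max_text_length` into `char_ops_params`, while the `character.py` hunk pins `self.max_text_len` to 25, so the forwarded value only takes effect when it happens to be 25. A minimal sketch of a variant that reads the parameter with a fallback, assuming the same config-dict interface as in the diff (illustrative only, not the committed change):

```python
# Sketch only (assumes the config-dict interface shown in the diff above):
# honour the max_text_length passed from predict_rec.py, falling back to
# SRN's default sequence length of 25 when the key is absent.
class CharacterOps(object):
    def __init__(self, config):
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        self.max_text_len = config.get('max_text_length', 25)


if __name__ == "__main__":
    ops = CharacterOps({
        "character_type": "en",
        "loss_type": "srn",
        "max_text_length": 25,
    })
    print(ops.max_text_len)  # -> 25
```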
parser.add_argument("--rec_image_shape", type=str, default="1, 64, 320") + parser.add_argument("--rec_char_type", type=str, default='en') parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument( "--rec_char_dict_path", diff --git a/tools/program.py b/tools/program.py index 6ebc27cb212ecfe1229dd5179f12514688909751..354cf8dd7cbdb35f9f2cafde988e6507471a9abf 100755 --- a/tools/program.py +++ b/tools/program.py @@ -194,13 +194,14 @@ def build(config, main_prog, startup_prog, mode): global_lr = optimizer._global_learning_rate() fetch_name_list.insert(0, "lr") fetch_varname_list.insert(0, global_lr.name) - if config['Global']["loss_type"] == 'srn': - model_average = fluid.optimizer.ModelAverage( - config['Global']['average_window'], - min_average_window=config['Global'][ - 'min_average_window'], - max_average_window=config['Global'][ - 'max_average_window']) + if "loss_type" in config["Global"]: + if config['Global']["loss_type"] == 'srn': + model_average = fluid.optimizer.ModelAverage( + config['Global']['average_window'], + min_average_window=config['Global'][ + 'min_average_window'], + max_average_window=config['Global'][ + 'max_average_window']) return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name, model_average)