提交 9cb30720 编写于 作者: T tink2123

fix bug and update doc

上级 d3ed210a
......@@ -122,7 +122,10 @@ PaddleOCR开源的文本识别算法列表:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研, coming soon)
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研)
*备注:* SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](todo)上下载。
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。
参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
......@@ -136,6 +139,7 @@ PaddleOCR开源的文本识别算法列表:
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
......
......@@ -25,7 +25,7 @@ class CharacterOps(object):
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
self.max_text_len = 25
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
......
......@@ -40,7 +40,8 @@ class TextRecognizer(object):
char_ops_params = {
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
"use_space_char": args.use_space_char,
"max_text_length": args.max_text_length
}
if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc'
......
......@@ -56,8 +56,8 @@ def parse_args():
#params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_image_shape", type=str, default="1, 64, 320")
parser.add_argument("--rec_char_type", type=str, default='en')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument(
"--rec_char_dict_path",
......
......@@ -194,13 +194,14 @@ def build(config, main_prog, startup_prog, mode):
global_lr = optimizer._global_learning_rate()
fetch_name_list.insert(0, "lr")
fetch_varname_list.insert(0, global_lr.name)
if config['Global']["loss_type"] == 'srn':
model_average = fluid.optimizer.ModelAverage(
config['Global']['average_window'],
min_average_window=config['Global'][
'min_average_window'],
max_average_window=config['Global'][
'max_average_window'])
if "loss_type" in config["Global"]:
if config['Global']["loss_type"] == 'srn':
model_average = fluid.optimizer.ModelAverage(
config['Global']['average_window'],
min_average_window=config['Global'][
'min_average_window'],
max_average_window=config['Global'][
'max_average_window'])
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,
model_average)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册