From a3b291928b64687b5e6f437636eeeaa2e785f6e5 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Sun, 16 Aug 2020 16:46:22 +0800 Subject: [PATCH] polish code --- doc/doc_ch/config.md | 3 +++ ppocr/modeling/architectures/rec_model.py | 3 +++ tools/eval_utils/eval_rec_utils.py | 6 +++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/doc/doc_ch/config.md b/doc/doc_ch/config.md index 5e579096..03fe1b32 100644 --- a/doc/doc_ch/config.md +++ b/doc/doc_ch/config.md @@ -32,6 +32,9 @@ | loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention | | distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) | | use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 | +| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN | +| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目| +| min_average_window | 平均值计算窗口长度的最小值 | 10000 | \ | | reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ | | pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ | | checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 | diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index 5eacd5de..f4e3eea2 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -213,6 +213,9 @@ class RecModel(object): predict = predicts['predict'] if self.loss_type == "ctc": predict = fluid.layers.softmax(predict) + if self.loss_type == "srn": + logger.infor( + "Warning! SRN does not support export model currently") return [image, {'decoded_out': decoded_out, 'predicts': predict}] else: predict = predicts['predict'] diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index 5a653678..4479d9df 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -69,7 +69,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): return_numpy=False) preds = np.array(outs[0]) - if preds.shape[1] != 1: + if config['Global']['loss_type'] == "attention": preds, preds_lod = convert_rec_attention_infer_res(preds) else: preds_lod = outs[0].lod()[0] @@ -123,8 +123,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode): def test_rec_benchmark(exe, config, eval_info_dict): " Evaluate lmdb dataset " - eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860','IC03_867', \ - 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80'] + eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \ + 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] eval_data_dir = config['TestReader']['lmdb_sets_dir'] total_evaluation_data_number = 0 total_correct_number = 0 -- GitLab