From ce1b4a348f04504c196931c84bd22cd04e21c7a2 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 29 May 2020 15:51:13 +0800 Subject: [PATCH] add ano --- ppocr/modeling/heads/rec_ctc_head.py | 8 ++++ ppocr/utils/character.py | 32 ++++++++++++++-- tools/eval_utils/eval_rec_utils.py | 4 +- tools/program.py | 56 ++++++++++++++++++++++++---- 4 files changed, 87 insertions(+), 13 deletions(-) diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 37b4b00f..5773adce 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -27,6 +27,14 @@ import numpy as np class CTCPredict(object): + """ + CTC predict + + Args: + params(object): Params from yaml file and settings from command line + + """ + def __init__(self, params): super(CTCPredict, self).__init__() self.char_num = params['char_num'] diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index 42b63173..13ba079b 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -149,12 +149,16 @@ def cal_predicts_accuracy(char_ops, Args: char_ops: CharacterOps preds: preds result,text index - preds_lod: - labels: - labels_lod: - is_remove_duplicate: + preds_lod: lod tensor of preds + labels: label of input image, text index + labels_lod: lod tensor of label + is_remove_duplicate: Whether to remove duplicate characters, + The default is False Return: + acc: The accuracy of test set + acc_num: The correct number of samples predicted + img_num: The total sample number of the test set """ acc_num = 0 @@ -178,6 +182,16 @@ def cal_predicts_accuracy(char_ops, def convert_rec_attention_infer_res(preds): + """ + Convert recognition attention predict result with lod information + + Args: + preds: the output of the model + + Return: + convert_ids: A 1-D Tensor represents all the predicted results. 
+ target_lod: The lod information of the predicted results + """ img_num = preds.shape[0] target_lod = [0] convert_ids = [] @@ -195,6 +209,16 @@ def convert_rec_attention_infer_res(preds): def convert_rec_label_to_lod(ori_labels): + """ + Convert recognition label to lod tensor + + Args: + ori_labels: original labels of all images + Return: + convert_ids: A 1-D Tensor representing all labels + target_lod: The lod information of the labels + + """ img_num = len(ori_labels) target_lod = [0] convert_ids = [] diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index 2d7d7e1d..d316ae51 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -83,7 +83,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode): def test_rec_benchmark(exe, config, eval_info_dict): - " 评估lmdb 数据" + """ + eval rec benchmark + """ eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \ 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] eval_data_dir = config['TestReader']['lmdb_sets_dir'] diff --git a/tools/program.py b/tools/program.py index a114b1cb..b50d154b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -35,6 +35,10 @@ from ppocr.utils.character import cal_predicts_accuracy class ArgsParser(ArgumentParser): + """ + Parse arguments + """ + def __init__(self): super(ArgsParser, self).__init__( formatter_class=RawDescriptionHelpFormatter) @@ -61,7 +65,9 @@ class ArgsParser(ArgumentParser): class AttrDict(dict): - """Single level attribute dict, NOT recursive""" + """ + Single level attribute dict, NOT recursive + """ def __init__(self, **kwargs): super(AttrDict, self).__init__() @@ -146,21 +152,22 @@ def check_gpu(use_gpu): def build(config, main_prog, startup_prog, mode): """ Build a program using a model and an optimizer - 1. create feeds - 2. create a dataloader - 3. create a model - 4. create fetchs - 5. create an optimizer + 1. create a dataloader + 2. create a model + 3. 
create fetches + 4. create an optimizer Args: config(dict): config main_prog(): main program startup_prog(): startup program - is_train(bool): train or valid + mode(str): train or valid Returns: dataloader(): a bridge between the model and the data - fetchs(dict): dict of model outputs(included loss and measures) + fetch_name_list(dict): dict of model outputs (including loss and measures) + fetch_varname_list(list): list of the outputs' var names + opt_loss_name(str): name of loss """ with fluid.program_guard(main_prog, startup_prog): with fluid.unique_name.guard(): @@ -185,6 +192,19 @@ def build(config, main_prog, startup_prog, mode): def build_export(config, main_prog, startup_prog): """ + Build a program for exporting the model + 1. create a model + 2. create fetches + + Args: + config(dict): config + main_prog(): main program + startup_prog(): startup program + + Returns: + feeded_var_names(list): list of fed var names + target_vars(list): list of output[fetches_var] + fetches_var_name(list): list of fetch var names """ with fluid.program_guard(main_prog, startup_prog): with fluid.unique_name.guard(): @@ -212,6 +232,16 @@ def create_multi_devices_program(program, loss_var_name): def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): + """ + Feed data to the model and fetch the measures and loss for detection + + Args: + config: config + exe: + train_info_dict: information dict for training + eval_info_dict: information dict for evaluation + + """ train_batch_id = 0 log_smooth_window = config['Global']['log_smooth_window'] epoch_num = config['Global']['epoch_num'] @@ -277,6 +307,16 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): + """ + Feed data to the model and fetch the measures and loss for recognition + + Args: + config: config + exe: + train_info_dict: information dict for training + eval_info_dict: information dict for evaluation + + """ train_batch_id = 0 
log_smooth_window = config['Global']['log_smooth_window'] epoch_num = config['Global']['epoch_num'] -- GitLab