提交 ce1b4a34 编写于 作者: T tink2123

add ano

上级 6dd494b6
......@@ -27,6 +27,14 @@ import numpy as np
class CTCPredict(object):
"""
CTC predict
Args:
params(object): Params from yaml file and settings from command line
"""
def __init__(self, params):
super(CTCPredict, self).__init__()
self.char_num = params['char_num']
......
......@@ -149,12 +149,16 @@ def cal_predicts_accuracy(char_ops,
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod:
labels:
labels_lod:
is_remove_duplicate:
preds_lod: lod tensor of preds
labels: label of input image, text index
labels_lod: lod tensor of label
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
acc: The accuracy of test set
acc_num: The correct number of samples predicted
img_num: The total sample number of the test set
"""
acc_num = 0
......@@ -178,6 +182,16 @@ def cal_predicts_accuracy(char_ops,
def convert_rec_attention_infer_res(preds):
"""
Convert recognition attention predict result with lod information
Args:
preds: the output of the model
Return:
convert_ids: A 1-D Tensor represents all the predicted results.
target_lod: The lod information of the predicted results
"""
img_num = preds.shape[0]
target_lod = [0]
convert_ids = []
......@@ -195,6 +209,16 @@ def convert_rec_attention_infer_res(preds):
def convert_rec_label_to_lod(ori_labels):
"""
Convert recognition label to lod tensor
Args:
ori_labels: origin labels of total images
Return:
convert_ids: A 1-D Tensor represents all labels
target_lod: The lod information of the labels
"""
img_num = len(ori_labels)
target_lod = [0]
convert_ids = []
......
......@@ -83,7 +83,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def test_rec_benchmark(exe, config, eval_info_dict):
" 评估lmdb 数据"
"""
eval rec benchmark
"""
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
eval_data_dir = config['TestReader']['lmdb_sets_dir']
......
......@@ -35,6 +35,10 @@ from ppocr.utils.character import cal_predicts_accuracy
class ArgsParser(ArgumentParser):
"""
Parase arguments
"""
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
......@@ -61,7 +65,9 @@ class ArgsParser(ArgumentParser):
class AttrDict(dict):
"""Single level attribute dict, NOT recursive"""
"""
Single level attribute dict, NOT recursive
"""
def __init__(self, **kwargs):
super(AttrDict, self).__init__()
......@@ -146,21 +152,22 @@ def check_gpu(use_gpu):
def build(config, main_prog, startup_prog, mode):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
1. create a dataloader
2. create a model
3. create fetchs
4. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
mode(str): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
fetch_name_list(dict): dict of model outputs(included loss and measures)
fetch_varname_list(list): list of outputs' varname
opt_loss_name(str): name of loss
"""
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
......@@ -185,6 +192,19 @@ def build(config, main_prog, startup_prog, mode):
def build_export(config, main_prog, startup_prog):
"""
Build a program for export model
1. create a model
2. create fetchs
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
Returns:
feeded_var_names(list): list of feeded var names
target_vars(list): list of output[fetches_var]
fetches_var_name(list): list of fetch var name
"""
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
......@@ -212,6 +232,16 @@ def create_multi_devices_program(program, loss_var_name):
def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
"""
Feed data to the model and fetch the measures and loss for detection
Args:
config: config
exe:
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
......@@ -277,6 +307,16 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
"""
Feed data to the model and fetch the measures and loss for recognition
Args:
config: config
exe:
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册