提交 ce1b4a34 编写于 作者: T tink2123

add ano

上级 6dd494b6
...@@ -27,6 +27,14 @@ import numpy as np ...@@ -27,6 +27,14 @@ import numpy as np
class CTCPredict(object): class CTCPredict(object):
"""
CTC predict
Args:
params(object): Params from yaml file and settings from command line
"""
def __init__(self, params): def __init__(self, params):
super(CTCPredict, self).__init__() super(CTCPredict, self).__init__()
self.char_num = params['char_num'] self.char_num = params['char_num']
......
...@@ -149,12 +149,16 @@ def cal_predicts_accuracy(char_ops, ...@@ -149,12 +149,16 @@ def cal_predicts_accuracy(char_ops,
Args: Args:
char_ops: CharacterOps char_ops: CharacterOps
preds: preds result,text index preds: preds result,text index
preds_lod: preds_lod: lod tensor of preds
labels: labels: label of input image, text index
labels_lod: labels_lod: lod tensor of label
is_remove_duplicate: is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return: Return:
acc: The accuracy of test set
acc_num: The correct number of samples predicted
img_num: The total sample number of the test set
""" """
acc_num = 0 acc_num = 0
...@@ -178,6 +182,16 @@ def cal_predicts_accuracy(char_ops, ...@@ -178,6 +182,16 @@ def cal_predicts_accuracy(char_ops,
def convert_rec_attention_infer_res(preds): def convert_rec_attention_infer_res(preds):
"""
Convert recognition attention predict result with lod information
Args:
preds: the output of the model
Return:
convert_ids: A 1-D Tensor represents all the predicted results.
target_lod: The lod information of the predicted results
"""
img_num = preds.shape[0] img_num = preds.shape[0]
target_lod = [0] target_lod = [0]
convert_ids = [] convert_ids = []
...@@ -195,6 +209,16 @@ def convert_rec_attention_infer_res(preds): ...@@ -195,6 +209,16 @@ def convert_rec_attention_infer_res(preds):
def convert_rec_label_to_lod(ori_labels): def convert_rec_label_to_lod(ori_labels):
"""
Convert recognition label to lod tensor
Args:
ori_labels: origin labels of total images
Return:
convert_ids: A 1-D Tensor represents all labels
target_lod: The lod information of the labels
"""
img_num = len(ori_labels) img_num = len(ori_labels)
target_lod = [0] target_lod = [0]
convert_ids = [] convert_ids = []
......
...@@ -83,7 +83,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode): ...@@ -83,7 +83,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def test_rec_benchmark(exe, config, eval_info_dict): def test_rec_benchmark(exe, config, eval_info_dict):
" 评估lmdb 数据" """
eval rec benchmark
"""
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \ eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
eval_data_dir = config['TestReader']['lmdb_sets_dir'] eval_data_dir = config['TestReader']['lmdb_sets_dir']
......
...@@ -35,6 +35,10 @@ from ppocr.utils.character import cal_predicts_accuracy ...@@ -35,6 +35,10 @@ from ppocr.utils.character import cal_predicts_accuracy
class ArgsParser(ArgumentParser): class ArgsParser(ArgumentParser):
"""
Parase arguments
"""
def __init__(self): def __init__(self):
super(ArgsParser, self).__init__( super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter) formatter_class=RawDescriptionHelpFormatter)
...@@ -61,7 +65,9 @@ class ArgsParser(ArgumentParser): ...@@ -61,7 +65,9 @@ class ArgsParser(ArgumentParser):
class AttrDict(dict): class AttrDict(dict):
"""Single level attribute dict, NOT recursive""" """
Single level attribute dict, NOT recursive
"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(AttrDict, self).__init__() super(AttrDict, self).__init__()
...@@ -146,21 +152,22 @@ def check_gpu(use_gpu): ...@@ -146,21 +152,22 @@ def check_gpu(use_gpu):
def build(config, main_prog, startup_prog, mode): def build(config, main_prog, startup_prog, mode):
""" """
Build a program using a model and an optimizer Build a program using a model and an optimizer
1. create feeds 1. create a dataloader
2. create a dataloader 2. create a model
3. create a model 3. create fetchs
4. create fetchs 4. create an optimizer
5. create an optimizer
Args: Args:
config(dict): config config(dict): config
main_prog(): main program main_prog(): main program
startup_prog(): startup program startup_prog(): startup program
is_train(bool): train or valid mode(str): train or valid
Returns: Returns:
dataloader(): a bridge between the model and the data dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures) fetch_name_list(dict): dict of model outputs(included loss and measures)
fetch_varname_list(list): list of outputs' varname
opt_loss_name(str): name of loss
""" """
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
...@@ -185,6 +192,19 @@ def build(config, main_prog, startup_prog, mode): ...@@ -185,6 +192,19 @@ def build(config, main_prog, startup_prog, mode):
def build_export(config, main_prog, startup_prog): def build_export(config, main_prog, startup_prog):
""" """
Build a program for export model
1. create a model
2. create fetchs
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
Returns:
feeded_var_names(list): list of feeded var names
target_vars(list): list of output[fetches_var]
fetches_var_name(list): list of fetch var name
""" """
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
...@@ -212,6 +232,16 @@ def create_multi_devices_program(program, loss_var_name): ...@@ -212,6 +232,16 @@ def create_multi_devices_program(program, loss_var_name):
def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
"""
Feed data to the model and fetch the measures and loss for detection
Args:
config: config
exe:
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0 train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window'] log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num'] epoch_num = config['Global']['epoch_num']
...@@ -277,6 +307,16 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): ...@@ -277,6 +307,16 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
"""
Feed data to the model and fetch the measures and loss for recognition
Args:
config: config
exe:
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0 train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window'] log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num'] epoch_num = config['Global']['epoch_num']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册