diff --git a/tools/eval.py b/tools/eval.py index a5b8e26a52602250b6b4529fdc54074c53263138..edd84a9dd98133c1e700eb35561a3ab287bd3162 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -45,26 +45,10 @@ from ppocr.utils.save_load import init_model from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import test_rec_benchmark from eval_utils.eval_rec_utils import eval_rec_run -from ppocr.utils.character import CharacterOps def main(): - config = program.load_config(FLAGS.config) - program.merge_config(FLAGS.opt) - logger.info(config) - - # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] - program.check_gpu(use_gpu) - - alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] - if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: - config['Global']['char_ops'] = CharacterOps(config['Global']) - - place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() - startup_prog = fluid.Program() - eval_program = fluid.Program() + startup_prog, eval_program, place, config, train_alg_type = program.preprocess() eval_build_outputs = program.build( config, eval_program, startup_prog, mode='test') eval_fetch_name_list = eval_build_outputs[1] @@ -75,7 +59,7 @@ def main(): init_model(config, eval_program, exe) - if alg in ['EAST', 'DB']: + if train_alg_type == 'det': eval_reader = reader_main(config=config, mode="eval") eval_info_dict = {'program':eval_program,\ 'reader':eval_reader,\ @@ -101,6 +85,4 @@ def main(): if __name__ == '__main__': - parser = program.ArgsParser() - FLAGS = parser.parse_args() main() diff --git a/tools/export_model.py b/tools/export_model.py index 4415eda84048395da094c0bc6e5685979f94a2e8..de4ba0e4c44fec1cd2427bfe7c9065639eef26e2 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -41,27 +41,11 @@ from paddle import fluid from ppocr.utils.utility import initial_logger logger = initial_logger() from ppocr.utils.save_load import init_model -from ppocr.utils.character import CharacterOps -from ppocr.utils.utility import create_module -def main(): - config = program.load_config(FLAGS.config) - program.merge_config(FLAGS.opt) - logger.info(config) - - # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] - program.check_gpu(use_gpu) - alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] - if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: - config['Global']['char_ops'] = CharacterOps(config['Global']) - - place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() - startup_prog = fluid.Program() - eval_program = fluid.Program() +def main(): + startup_prog, eval_program, place, config, _ = program.preprocess() feeded_var_names, target_vars, fetches_var_name = program.build_export( config, eval_program, startup_prog) @@ -88,6 +72,4 @@ def main(): if __name__ == '__main__': - parser = program.ArgsParser() - FLAGS = parser.parse_args() main() diff --git a/tools/program.py b/tools/program.py index ff8743f15e2925a92fff76e7430e761d7baa720e..4ebc11670702ba627c89b060692c9827e6e163fd 100755 --- a/tools/program.py +++ b/tools/program.py @@ -22,6 +22,7 @@ import yaml import os from ppocr.utils.utility import create_module from ppocr.utils.utility import initial_logger + logger = initial_logger() import paddle.fluid as fluid @@ -31,8 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run from ppocr.utils.save_load import save_model import numpy as np -from ppocr.utils.character import cal_predicts_accuracy - +from ppocr.utils.character import cal_predicts_accuracy, CharacterOps class ArgsParser(ArgumentParser): def __init__(self): @@ -374,3 +374,29 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_model(train_info_dict['train_program'], save_path) return + +def preprocess(): + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + logger.info(config) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + check_gpu(use_gpu) + + alg = config['Global']['algorithm'] + assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] + if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: + config['Global']['char_ops'] = CharacterOps(config['Global']) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + startup_program = fluid.Program() + train_program = fluid.Program() + + if alg in ['EAST', 'DB']: + train_alg_type = 'det' + else: + train_alg_type = 'rec' + + return startup_program, train_program, place, config, train_alg_type diff --git a/tools/train.py b/tools/train.py index 29205483cc65ade1fdea1ea4cdb711369170cd32..0f5e9039cfd4cc23e418434323b74c6612587ed2 100755 --- a/tools/train.py +++ b/tools/train.py @@ -42,27 +42,10 @@ from ppocr.utils.utility import initial_logger logger = initial_logger() from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model -from ppocr.utils.character import CharacterOps from paddle.fluid.contrib.model_stat import summary def main(): - config = program.load_config(FLAGS.config) - program.merge_config(FLAGS.opt) - logger.info(config) - - # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] - program.check_gpu(use_gpu) - - alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] - if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: - config['Global']['char_ops'] = CharacterOps(config['Global']) - - place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() - startup_program = fluid.Program() - train_program = fluid.Program() train_build_outputs = program.build( config, train_program, startup_program, mode='train') train_loader = train_build_outputs[0] @@ -109,15 +92,13 @@ def main(): 'fetch_name_list':eval_fetch_name_list,\ 'fetch_varname_list':eval_fetch_varname_list} - if alg in ['EAST', 'DB']: + if train_alg_type == 'det': program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) else: program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) def test_reader(): - config = program.load_config(FLAGS.config) - program.merge_config(FLAGS.opt) logger.info(config) train_reader = reader_main(config=config, mode="train") import time @@ -136,7 +117,6 @@ def test_reader(): if __name__ == '__main__': - parser = program.ArgsParser() - FLAGS = parser.parse_args() + startup_program, train_program, place, config, train_alg_type = program.preprocess() main() # test_reader()