diff --git a/tools/export_model.py b/tools/export_model.py index b6c03efba7b48f23ffb8794996972c79bee76c57..51c061788e65575a8a8c69ba60d42a6334b4ad5e 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -28,21 +28,15 @@ from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.utils.save_load import init_model from ppocr.utils.logging import get_logger -from tools.program import load_config - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", help="configuration file to use") - parser.add_argument( - "-o", "--output_path", type=str, default='./output/infer/') - return parser.parse_args() +from tools.program import load_config, merge_config,ArgsParser def main(): - FLAGS = parse_args() + FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) + merge_config(FLAGS.opt) logger = get_logger() + print(config) # build post process post_process_class = build_post_process(config['PostProcess'], @@ -57,7 +51,7 @@ def main(): init_model(config, model, logger) model.eval() - save_path = '{}/inference'.format(FLAGS.output_path) + save_path = '{}/inference'.format(config['Global']['save_inference_dir']) infer_shape = [3, 32, 100] if config['Architecture'][ 'model_type'] != "det" else [3, 640, 640] model = to_static(