From d89d3384a0de021053bed91c92a84b850aebdff8 Mon Sep 17 00:00:00 2001 From: lyl120117 <278401555@qq.com> Date: Sat, 11 Jul 2020 12:14:05 +0800 Subject: [PATCH] dump model structure --- configs/rec/rec_icdar15_train.yml | 1 + tools/program.py | 3 +++ tools/train.py | 9 +++++++++ 3 files changed, 13 insertions(+) diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index d0b75628..98a38e74 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -14,6 +14,7 @@ Global: character_type: en loss_type: ctc distort: true + debug: false reader_yml: ./configs/rec/rec_icdar15_reader.yml pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: diff --git a/tools/program.py b/tools/program.py index 870d2700..6a51e5c3 100755 --- a/tools/program.py +++ b/tools/program.py @@ -75,6 +75,8 @@ class AttrDict(dict): global_config = AttrDict() +default_config = {'Global': {'debug': False, }} + def load_config(file_path): """ @@ -85,6 +87,7 @@ def load_config(file_path): Returns: global config """ + merge_config(default_config) _, ext = os.path.splitext(file_path) assert ext in ['.yml', '.yaml'], "only support yaml files for now" merge_config(yaml.load(open(file_path), Loader=yaml.Loader)) diff --git a/tools/train.py b/tools/train.py index 15d6ebb2..c8350ff6 100755 --- a/tools/train.py +++ b/tools/train.py @@ -43,6 +43,7 @@ logger = initial_logger() from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model from ppocr.utils.character import CharacterOps +from paddle.fluid.contrib.model_stat import summary def main(): @@ -87,6 +88,14 @@ def main(): # compile program for multi-devices train_compile_program = program.create_multi_devices_program( train_program, train_opt_loss_name) + + # dump mode structure + if config['Global']['debug']: + if 'Attention' in config['Head'].keys(): + logger.warning('Does not suport dump attention...') + else: + summary(train_program) + init_model(config, train_program, exe) train_info_dict = {'compile_program':train_compile_program,\ -- GitLab