diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index d0b75628c58833447333de36490141847f1815e4..98a38e7477f725c605c0cf017b6a7a4b469f7f3b 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -14,6 +14,7 @@ Global: character_type: en loss_type: ctc distort: true + debug: false reader_yml: ./configs/rec/rec_icdar15_reader.yml pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: diff --git a/tools/program.py b/tools/program.py index 870d27002f36bbed4b7a665f4ff9bc9cc420f0c1..6a51e5c37175f0f45e87571122adc2aba04d491c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -75,6 +75,8 @@ class AttrDict(dict): global_config = AttrDict() +default_config = {'Global': {'debug': False, }} + def load_config(file_path): """ @@ -85,6 +87,7 @@ def load_config(file_path): Returns: global config """ + merge_config(default_config) _, ext = os.path.splitext(file_path) assert ext in ['.yml', '.yaml'], "only support yaml files for now" merge_config(yaml.load(open(file_path), Loader=yaml.Loader)) diff --git a/tools/train.py b/tools/train.py index 15d6ebb2138ce19a2f65c7d1fabd56d86b7645be..c8350ff64b4894cb22bde063529786b0945dfea3 100755 --- a/tools/train.py +++ b/tools/train.py @@ -43,6 +43,7 @@ logger = initial_logger() from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model from ppocr.utils.character import CharacterOps +from paddle.fluid.contrib.model_stat import summary def main(): @@ -87,6 +88,14 @@ def main(): # compile program for multi-devices train_compile_program = program.create_multi_devices_program( train_program, train_opt_loss_name) + + # dump mode structure + if config['Global']['debug']: + if 'Attention' in config['Head'].keys(): + logger.warning('Does not suport dump attention...') + else: + summary(train_program) + init_model(config, train_program, exe) train_info_dict = {'compile_program':train_compile_program,\