From 9cf6c4e8340b16c568721db914ab05c79216dbdd Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 29 Sep 2021 01:59:43 +0000 Subject: [PATCH] fix profile_options --- tools/program.py | 9 ++++++--- tools/train.py | 11 +++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tools/program.py b/tools/program.py index d941e717..18fead8b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -159,8 +159,7 @@ def train(config, eval_class, pre_best_model_dict, logger, - vdl_writer=None, - profiler_options=None): + vdl_writer=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) log_smooth_window = config['Global']['log_smooth_window'] @@ -168,6 +167,8 @@ def train(config, print_batch_step = config['Global']['print_batch_step'] eval_batch_step = config['Global']['eval_batch_step'] + profiler_options = config['profiler_options'] + global_step = 0 if 'global_step' in pre_best_model_dict: global_step = pre_best_model_dict['global_step'] @@ -405,6 +406,8 @@ def preprocess(is_train=False): profiler_options = FLAGS.profiler_options config = load_config(FLAGS.config) merge_config(FLAGS.opt) + profile_dic = {"profiler_options": FLAGS.profiler_options} + merge_config(profile_dic) # check if set use_gpu=True in paddlepaddle cpu version use_gpu = config['Global']['use_gpu'] @@ -442,4 +445,4 @@ def preprocess(is_train=False): print_dict(config, logger) logger.info('train with paddle {} and device {}'.format(paddle.__version__, device)) - return config, device, logger, vdl_writer, profiler_options + return config, device, logger, vdl_writer diff --git a/tools/train.py b/tools/train.py index 17a12390..ee819614 100755 --- a/tools/train.py +++ b/tools/train.py @@ -41,7 +41,7 @@ import tools.program as program dist.get_world_size() -def main(config, device, logger, vdl_writer, profiler_options): +def main(config, device, logger, vdl_writer): # init dist environment if config['Global']['distributed']: dist.init_parallel_env() @@ -105,8 +105,7 @@ def main(config, device, logger, vdl_writer, profiler_options): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, - profiler_options) + eval_class, pre_best_model_dict, logger, vdl_writer) def test_reader(config, device, logger): @@ -128,8 +127,8 @@ def test_reader(config, device, logger): if __name__ == '__main__': - config, device, logger, vdl_writer, profiler_options = program.preprocess( + config, device, logger, vdl_writer = program.preprocess( is_train=True) - main(config, device, logger, vdl_writer, profiler_options) - # test_reader(config, device, logger) + logger.info(f"config.profiler_options: {config.profiler_options}") + main(config, device, logger, vdl_writer) # test_reader(config, device, logger) -- GitLab