diff --git a/tools/program.py b/tools/program.py index d941e717608f3d4cfba6cde075ae32e6968c5d94..18fead8ba81072202cf1ae13f32b21522a5e257f 100755 --- a/tools/program.py +++ b/tools/program.py @@ -159,8 +159,7 @@ def train(config, eval_class, pre_best_model_dict, logger, - vdl_writer=None, - profiler_options=None): + vdl_writer=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) log_smooth_window = config['Global']['log_smooth_window'] @@ -168,6 +167,8 @@ def train(config, print_batch_step = config['Global']['print_batch_step'] eval_batch_step = config['Global']['eval_batch_step'] + profiler_options = config['profiler_options'] + global_step = 0 if 'global_step' in pre_best_model_dict: global_step = pre_best_model_dict['global_step'] @@ -405,6 +406,8 @@ def preprocess(is_train=False): profiler_options = FLAGS.profiler_options config = load_config(FLAGS.config) merge_config(FLAGS.opt) + profile_dic = {"profiler_options": FLAGS.profiler_options} + merge_config(profile_dic) # check if set use_gpu=True in paddlepaddle cpu version use_gpu = config['Global']['use_gpu'] @@ -442,4 +445,4 @@ def preprocess(is_train=False): print_dict(config, logger) logger.info('train with paddle {} and device {}'.format(paddle.__version__, device)) - return config, device, logger, vdl_writer, profiler_options + return config, device, logger, vdl_writer diff --git a/tools/train.py b/tools/train.py index 17a12390405932262090942e2da1ac95e991d062..ee81961414c2321373c9d18061cb2d1daf4d8b98 100755 --- a/tools/train.py +++ b/tools/train.py @@ -41,7 +41,7 @@ import tools.program as program dist.get_world_size() -def main(config, device, logger, vdl_writer, profiler_options): +def main(config, device, logger, vdl_writer): # init dist environment if config['Global']['distributed']: dist.init_parallel_env() @@ -105,8 +105,7 @@ def main(config, device, logger, vdl_writer, profiler_options): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, - profiler_options) + eval_class, pre_best_model_dict, logger, vdl_writer) def test_reader(config, device, logger): @@ -128,8 +127,8 @@ def test_reader(config, device, logger): if __name__ == '__main__': - config, device, logger, vdl_writer, profiler_options = program.preprocess( + config, device, logger, vdl_writer = program.preprocess( is_train=True) - main(config, device, logger, vdl_writer, profiler_options) - # test_reader(config, device, logger) + logger.info(f"config.profiler_options: {config.profiler_options}") + main(config, device, logger, vdl_writer) # test_reader(config, device, logger)