Commit 9cf6c4e8 authored by LDOUBLEV

fix profile_options

Parent 8dd14799
@@ -159,8 +159,7 @@ def train(config,
           eval_class,
           pre_best_model_dict,
           logger,
-          vdl_writer=None,
-          profiler_options=None):
+          vdl_writer=None):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
     log_smooth_window = config['Global']['log_smooth_window']
@@ -168,6 +167,8 @@ def train(config,
     print_batch_step = config['Global']['print_batch_step']
     eval_batch_step = config['Global']['eval_batch_step']
+    profiler_options = config['profiler_options']
+
     global_step = 0
     if 'global_step' in pre_best_model_dict:
         global_step = pre_best_model_dict['global_step']
@@ -405,6 +406,8 @@ def preprocess(is_train=False):
     profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
     merge_config(FLAGS.opt)
+    profile_dic = {"profiler_options": FLAGS.profiler_options}
+    merge_config(profile_dic)
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
@@ -442,4 +445,4 @@ def preprocess(is_train=False):
     print_dict(config, logger)
     logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                              device))
-    return config, device, logger, vdl_writer, profiler_options
+    return config, device, logger, vdl_writer
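In short, preprocess() no longer hands profiler_options back to the caller: the CLI flag is folded into the loaded config via merge_config, and train() reads it back as config['profiler_options']. Below is a minimal sketch of that flow, assuming a simplified stand-in for PaddleOCR's merge_config (the real helper also handles dotted keys and -o overrides, which are omitted here).

# Sketch of the new data flow: the profiler option string travels inside the
# config dict instead of being passed around as an extra argument.
# merge_config() here is a simplified stand-in, not PaddleOCR's implementation.

def merge_config(config, extra):
    """Shallow-merge a flat dict of overrides into the loaded config."""
    config.update(extra)
    return config

def preprocess(flags, config):
    # Mirrors the change above: stash the flag under a top-level config key
    # instead of returning it as a separate value.
    profile_dic = {"profiler_options": flags.get("profiler_options")}
    merge_config(config, profile_dic)
    return config

def train(config):
    # train() now pulls the option string back out of the config.
    profiler_options = config['profiler_options']
    print('profiler_options:', profiler_options)

if __name__ == '__main__':
    cfg = {'Global': {'use_gpu': False}}
    # Example option string; the exact keys depend on the profiler helper used.
    flags = {'profiler_options': 'batch_range=[10,20];profile_path=./profile_log'}
    train(preprocess(flags, cfg))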
@@ -41,7 +41,7 @@ import tools.program as program
 dist.get_world_size()


-def main(config, device, logger, vdl_writer, profiler_options):
+def main(config, device, logger, vdl_writer):
     # init dist environment
     if config['Global']['distributed']:
         dist.init_parallel_env()
@@ -105,8 +105,7 @@ def main(config, device, logger, vdl_writer, profiler_options):
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
-                  eval_class, pre_best_model_dict, logger, vdl_writer,
-                  profiler_options)
+                  eval_class, pre_best_model_dict, logger, vdl_writer)


 def test_reader(config, device, logger):
@@ -128,8 +127,8 @@ def test_reader(config, device, logger):
 if __name__ == '__main__':
-    config, device, logger, vdl_writer, profiler_options = program.preprocess(
+    config, device, logger, vdl_writer = program.preprocess(
         is_train=True)
-    main(config, device, logger, vdl_writer, profiler_options)
+    logger.info(f"config.profiler_options: {config.profiler_options}")
+    main(config, device, logger, vdl_writer)
     # test_reader(config, device, logger)
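Inside the training loop, the string stored in config['profiler_options'] is typically parsed into key=value pairs and used to decide, per step, whether profiling is active. The sketch below illustrates that pattern with a hypothetical parser and step hook; the key names (batch_range, profile_path) follow common Paddle-suite usage but are assumptions here, not PaddleOCR's exact schema.

# Hedged sketch: parse a semicolon-separated profiler option string and
# report, for a given global step, whether it falls inside the profiled range.
# parse_profiler_options() and add_profiler_step() are illustrative names,
# not PaddleOCR's actual helpers.

def parse_profiler_options(options_str):
    """Turn 'batch_range=[10,20];profile_path=./log' into a dict of options."""
    opts = {'batch_range': [10, 20], 'profile_path': './profile_log'}
    if not options_str:
        return opts
    for item in options_str.split(';'):
        key, _, value = item.partition('=')
        key, value = key.strip(), value.strip()
        if key == 'batch_range':
            lo, hi = value.strip('[]').split(',')
            opts[key] = [int(lo), int(hi)]
        elif key:
            opts[key] = value
    return opts

def add_profiler_step(options_str, global_step):
    """Return True while global_step is inside the configured batch range."""
    lo, hi = parse_profiler_options(options_str)['batch_range']
    return lo <= global_step < hi

if __name__ == '__main__':
    opts = 'batch_range=[10,20];profile_path=./profile_log'
    for step in (5, 10, 19, 20):
        print(step, add_profiler_step(opts, step))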