提交 9cf6c4e8 编写于 作者: L LDOUBLEV

fix profile_options

上级 8dd14799
......@@ -159,8 +159,7 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
vdl_writer=None,
profiler_options=None):
vdl_writer=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
log_smooth_window = config['Global']['log_smooth_window']
......@@ -168,6 +167,8 @@ def train(config,
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
profiler_options = config['profiler_options']
global_step = 0
if 'global_step' in pre_best_model_dict:
global_step = pre_best_model_dict['global_step']
......@@ -405,6 +406,8 @@ def preprocess(is_train=False):
profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
profile_dic = {"profiler_options": FLAGS.profiler_options}
merge_config(profile_dic)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
......@@ -442,4 +445,4 @@ def preprocess(is_train=False):
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, vdl_writer, profiler_options
return config, device, logger, vdl_writer
......@@ -41,7 +41,7 @@ import tools.program as program
dist.get_world_size()
def main(config, device, logger, vdl_writer, profiler_options):
def main(config, device, logger, vdl_writer):
# init dist environment
if config['Global']['distributed']:
dist.init_parallel_env()
......@@ -105,8 +105,7 @@ def main(config, device, logger, vdl_writer, profiler_options):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer,
profiler_options)
eval_class, pre_best_model_dict, logger, vdl_writer)
def test_reader(config, device, logger):
......@@ -128,8 +127,8 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
config, device, logger, vdl_writer, profiler_options = program.preprocess(
config, device, logger, vdl_writer = program.preprocess(
is_train=True)
main(config, device, logger, vdl_writer, profiler_options)
# test_reader(config, device, logger)
logger.info(f"config.profiler_options: {config.profiler_options}")
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册