From 738366287c0b7b31359ba8acca1c2b0201c2c09b Mon Sep 17 00:00:00 2001 From: baojun <32073718+baojun-nervana@users.noreply.github.com> Date: Wed, 12 Jun 2019 20:14:00 -0700 Subject: [PATCH] support train with ngraph (#158) * added a train file for ngraph * update train.py instead --- BERT/train.py | 47 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/BERT/train.py b/BERT/train.py index c8dc115..aaf2a87 100644 --- a/BERT/train.py +++ b/BERT/train.py @@ -313,13 +313,18 @@ def train(args): exec_strategy.num_threads = dev_count exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope - train_exe = fluid.ParallelExecutor( - use_cuda=args.use_cuda, - loss_name=total_loss.name, - exec_strategy=exec_strategy, - main_program=train_program, - num_trainers=nccl2_num_trainers, - trainer_id=nccl2_trainer_id) + # use_ngraph is for CPU only, please refer to README_ngraph.md for details + use_ngraph = os.getenv('FLAGS_use_ngraph') + if not use_ngraph: + train_exe = fluid.ParallelExecutor( + use_cuda=args.use_cuda, + loss_name=total_loss.name, + exec_strategy=exec_strategy, + main_program=train_program, + num_trainers=nccl2_num_trainers, + trainer_id=nccl2_trainer_id) + else: + train_exe = exe if args.validation_set_dir and args.validation_set_dir != "": predict = predict_wrapper( @@ -345,17 +350,30 @@ def train(args): skip_steps = args.skip_steps * nccl2_num_trainers if nccl2_trainer_id != 0: - train_exe.run(fetch_list=[]) + if use_ngraph: + train_exe.run(fetch_list=[], program=train_program) + else: + train_exe.run(fetch_list=[]) continue if steps % skip_steps != 0: - train_exe.run(fetch_list=[]) + if use_ngraph: + train_exe.run(fetch_list=[], program=train_program) + else: + train_exe.run(fetch_list=[]) + else: - each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run( - fetch_list=[ - next_sent_acc.name, mask_lm_loss.name, total_loss.name, - scheduled_lr.name - ]) + if 
use_ngraph: + each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run( + fetch_list=[ + next_sent_acc.name, mask_lm_loss.name, total_loss.name, + scheduled_lr.name], program=train_program) + else: + each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run( + fetch_list=[ + next_sent_acc.name, mask_lm_loss.name, total_loss.name, + scheduled_lr.name]) + acc.extend(each_next_acc) lm_cost.extend(each_mask_lm_cost) cost.extend(each_total_cost) @@ -398,7 +416,6 @@ def train(args): train_pyreader.reset() break - if __name__ == '__main__': print_arguments(args) if args.do_test: -- GitLab