diff --git a/BERT/train.py b/BERT/train.py
index c8dc1157217a23fa58e20682c67cd75f9c7169ad..aaf2a8758beee71f3fb74cf5207bc06457a01350 100644
--- a/BERT/train.py
+++ b/BERT/train.py
@@ -313,13 +313,18 @@ def train(args):
     exec_strategy.num_threads = dev_count
     exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
 
-    train_exe = fluid.ParallelExecutor(
-        use_cuda=args.use_cuda,
-        loss_name=total_loss.name,
-        exec_strategy=exec_strategy,
-        main_program=train_program,
-        num_trainers=nccl2_num_trainers,
-        trainer_id=nccl2_trainer_id)
+    # use_ngraph is for CPU only, please refer to README_ngraph.md for details
+    use_ngraph = os.getenv('FLAGS_use_ngraph')
+    if not use_ngraph:
+        train_exe = fluid.ParallelExecutor(
+            use_cuda=args.use_cuda,
+            loss_name=total_loss.name,
+            exec_strategy=exec_strategy,
+            main_program=train_program,
+            num_trainers=nccl2_num_trainers,
+            trainer_id=nccl2_trainer_id)
+    else:
+        train_exe = exe
 
     if args.validation_set_dir and args.validation_set_dir != "":
         predict = predict_wrapper(
@@ -345,17 +350,30 @@ def train(args):
             skip_steps = args.skip_steps * nccl2_num_trainers
 
             if nccl2_trainer_id != 0:
-                train_exe.run(fetch_list=[])
+                if use_ngraph:
+                    train_exe.run(fetch_list=[], program=train_program)
+                else:
+                    train_exe.run(fetch_list=[])
                 continue
 
             if steps % skip_steps != 0:
-                train_exe.run(fetch_list=[])
+                if use_ngraph:
+                    train_exe.run(fetch_list=[], program=train_program)
+                else:
+                    train_exe.run(fetch_list=[])
+
             else:
-                each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
-                    fetch_list=[
-                        next_sent_acc.name, mask_lm_loss.name, total_loss.name,
-                        scheduled_lr.name
-                    ])
+                if use_ngraph:
+                    each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
+                        fetch_list=[
+                            next_sent_acc.name, mask_lm_loss.name, total_loss.name,
+                            scheduled_lr.name], program=train_program)
+                else:
+                    each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
+                        fetch_list=[
+                            next_sent_acc.name, mask_lm_loss.name, total_loss.name,
+                            scheduled_lr.name])
+
                 acc.extend(each_next_acc)
                 lm_cost.extend(each_mask_lm_cost)
                 cost.extend(each_total_cost)
@@ -398,7 +416,6 @@ def train(args):
                 train_pyreader.reset()
                 break
 
-
 if __name__ == '__main__':
     print_arguments(args)
     if args.do_test:
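
Note (not part of the patch): the `else: train_exe = exe` branch works because a plain `fluid.Executor` is not bound to any program, so the nGraph/CPU path must pass `program=train_program` on every `run()` call, whereas `fluid.ParallelExecutor` is tied to its `main_program` at construction and only needs a `fetch_list`. Below is a minimal, self-contained sketch of that executor-selection pattern under stated assumptions: the toy fc program, `use_cuda`, and the variable names are illustrative stand-ins, not the file's actual BERT pretraining setup.

import os
import numpy as np
import paddle.fluid as fluid

# Sketch only: a tiny fc program replaces the real BERT pretraining program.
use_ngraph = os.getenv('FLAGS_use_ngraph')
use_cuda = False                      # assumption: CPU, matching the nGraph-is-CPU-only note
os.environ.setdefault('CPU_NUM', '1') # keep ParallelExecutor on a single CPU device

train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=1))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)

feed = {'x': np.random.rand(8, 4).astype('float32')}

if not use_ngraph:
    # ParallelExecutor is bound to its program at construction time,
    # so later run() calls only supply feed/fetch_list.
    train_exe = fluid.ParallelExecutor(
        use_cuda=use_cuda,
        loss_name=loss.name,
        main_program=train_program)
    loss_val, = train_exe.run(feed=feed, fetch_list=[loss.name])
else:
    # The plain Executor is program-agnostic; each run() must be told which
    # program to execute -- hence program=train_program in the patch above.
    train_exe = exe
    loss_val, = train_exe.run(program=train_program,
                              feed=feed,
                              fetch_list=[loss.name])

print(loss_val)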