提交 73836628 编写于 作者: B baojun 提交者: tensor-tang

support train with ngraph (#158)

* added a train file for ngraph

* updte train.py instead
上级 6d9efb96
...@@ -313,6 +313,9 @@ def train(args): ...@@ -313,6 +313,9 @@ def train(args):
exec_strategy.num_threads = dev_count exec_strategy.num_threads = dev_count
exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
# use_ngraph is for CPU only, please refer to README_ngraph.md for details
use_ngraph = os.getenv('FLAGS_use_ngraph')
if not use_ngraph:
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda, use_cuda=args.use_cuda,
loss_name=total_loss.name, loss_name=total_loss.name,
...@@ -320,6 +323,8 @@ def train(args): ...@@ -320,6 +323,8 @@ def train(args):
main_program=train_program, main_program=train_program,
num_trainers=nccl2_num_trainers, num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id) trainer_id=nccl2_trainer_id)
else:
train_exe = exe
if args.validation_set_dir and args.validation_set_dir != "": if args.validation_set_dir and args.validation_set_dir != "":
predict = predict_wrapper( predict = predict_wrapper(
...@@ -345,17 +350,30 @@ def train(args): ...@@ -345,17 +350,30 @@ def train(args):
skip_steps = args.skip_steps * nccl2_num_trainers skip_steps = args.skip_steps * nccl2_num_trainers
if nccl2_trainer_id != 0: if nccl2_trainer_id != 0:
if use_ngraph:
train_exe.run(fetch_list=[], program=train_program)
else:
train_exe.run(fetch_list=[]) train_exe.run(fetch_list=[])
continue continue
if steps % skip_steps != 0: if steps % skip_steps != 0:
if use_ngraph:
train_exe.run(fetch_list=[], program=train_program)
else:
train_exe.run(fetch_list=[]) train_exe.run(fetch_list=[])
else: else:
if use_ngraph:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run( each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
fetch_list=[ fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name, next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name scheduled_lr.name], program=train_program)
]) else:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name])
acc.extend(each_next_acc) acc.extend(each_next_acc)
lm_cost.extend(each_mask_lm_cost) lm_cost.extend(each_mask_lm_cost)
cost.extend(each_total_cost) cost.extend(each_total_cost)
...@@ -398,7 +416,6 @@ def train(args): ...@@ -398,7 +416,6 @@ def train(args):
train_pyreader.reset() train_pyreader.reset()
break break
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
if args.do_test: if args.do_test:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册