提交 73836628 编写于 作者: B baojun 提交者: tensor-tang

support train with ngraph (#158)

* added a train file for ngraph

* updte train.py instead
上级 6d9efb96
......@@ -313,13 +313,18 @@ def train(args):
exec_strategy.num_threads = dev_count
exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
loss_name=total_loss.name,
exec_strategy=exec_strategy,
main_program=train_program,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
# use_ngraph is for CPU only, please refer to README_ngraph.md for details
use_ngraph = os.getenv('FLAGS_use_ngraph')
if not use_ngraph:
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
loss_name=total_loss.name,
exec_strategy=exec_strategy,
main_program=train_program,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
else:
train_exe = exe
if args.validation_set_dir and args.validation_set_dir != "":
predict = predict_wrapper(
......@@ -345,17 +350,30 @@ def train(args):
skip_steps = args.skip_steps * nccl2_num_trainers
if nccl2_trainer_id != 0:
train_exe.run(fetch_list=[])
if use_ngraph:
train_exe.run(fetch_list=[], program=train_program)
else:
train_exe.run(fetch_list=[])
continue
if steps % skip_steps != 0:
train_exe.run(fetch_list=[])
if use_ngraph:
train_exe.run(fetch_list=[], program=train_program)
else:
train_exe.run(fetch_list=[])
else:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name
])
if use_ngraph:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name], program=train_program)
else:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run(
fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name])
acc.extend(each_next_acc)
lm_cost.extend(each_mask_lm_cost)
cost.extend(each_total_cost)
......@@ -398,7 +416,6 @@ def train(args):
train_pyreader.reset()
break
if __name__ == '__main__':
print_arguments(args)
if args.do_test:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册