From f0a09c3a8f5bfc324db36e420aa04b924a570974 Mon Sep 17 00:00:00 2001 From: zhengya01 <43601548+zhengya01@users.noreply.github.com> Date: Thu, 20 Feb 2020 21:48:20 +0800 Subject: [PATCH] add ce for dygraph_lac (#4327) --- dygraph/lac/train.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/dygraph/lac/train.py b/dygraph/lac/train.py index 00f294ea..401c2551 100755 --- a/dygraph/lac/train.py +++ b/dygraph/lac/train.py @@ -44,10 +44,11 @@ def do_train(args): with fluid.dygraph.guard(place): if args.use_data_parallel: strategy = fluid.dygraph.parallel.prepare_context() - #fluid.default_startup_program().random_seed = 102 - #fluid.default_main_program().random_seed = 102 - #np.random.seed(102) - #random.seed(102) + if args.enable_ce: + fluid.default_startup_program().random_seed = 102 + fluid.default_main_program().random_seed = 102 + np.random.seed(102) + random.seed(102) train_loader = reader.create_dataloader( args, file_name=args.train_data, @@ -102,6 +103,8 @@ def do_train(args): (precision, recall, f1, end_time - start_time)) model.train() + ce_time = [] + ce_infor = [] for epoch_id in range(args.epoch): for batch in train_loader(): words, targets, length = batch @@ -129,6 +132,8 @@ def do_train(args): print("[train] step = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f" % ( step, avg_cost, precision, recall, f1_score, end_time - start_time)) + ce_time.append(end_time - start_time) + ce_infor.append([precision, recall, f1_score]) if step % args.validation_steps == 0: test_process(test_loader, chunk_evaluator) @@ -139,7 +144,24 @@ def do_train(args): paddle.fluid.save_dygraph(model.state_dict(), save_path) step += 1 - + if args.enable_ce and fluid.dygraph.parallel.Env().local_rank == 0: + card_num = fluid.core.get_cuda_device_count() + _p = 0 + _r = 0 + _f1 = 0 + _time = 0 + try: + _time = ce_time[-1] + _p = ce_infor[-1][0] + _r = ce_infor[-1][1] + _f1 = ce_infor[-1][2] + except: + print("ce info error") + print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time)) + print("kpis\ttrain_p_card%s\t%f" % (card_num, _p)) + print("kpis\ttrain_r_card%s\t%f" % (card_num, _r)) + print("kpis\ttrain_f1_card%s\t%f" % (card_num, _f1)) + if __name__ == "__main__": # 参数控制可以根据需求使用argparse,yaml或者json -- GitLab