未验证 提交 f0a09c3a 编写于 作者: Z zhengya01 提交者: GitHub

add ce for dygraph_lac (#4327)

上级 837489ed
...@@ -44,10 +44,11 @@ def do_train(args): ...@@ -44,10 +44,11 @@ def do_train(args):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
if args.use_data_parallel: if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context() strategy = fluid.dygraph.parallel.prepare_context()
#fluid.default_startup_program().random_seed = 102 if args.enable_ce:
#fluid.default_main_program().random_seed = 102 fluid.default_startup_program().random_seed = 102
#np.random.seed(102) fluid.default_main_program().random_seed = 102
#random.seed(102) np.random.seed(102)
random.seed(102)
train_loader = reader.create_dataloader( train_loader = reader.create_dataloader(
args, args,
file_name=args.train_data, file_name=args.train_data,
...@@ -102,6 +103,8 @@ def do_train(args): ...@@ -102,6 +103,8 @@ def do_train(args):
(precision, recall, f1, end_time - start_time)) (precision, recall, f1, end_time - start_time))
model.train() model.train()
ce_time = []
ce_infor = []
for epoch_id in range(args.epoch): for epoch_id in range(args.epoch):
for batch in train_loader(): for batch in train_loader():
words, targets, length = batch words, targets, length = batch
...@@ -129,6 +132,8 @@ def do_train(args): ...@@ -129,6 +132,8 @@ def do_train(args):
print("[train] step = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f" % ( print("[train] step = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f" % (
step, avg_cost, precision, recall, f1_score, end_time - start_time)) step, avg_cost, precision, recall, f1_score, end_time - start_time))
ce_time.append(end_time - start_time)
ce_infor.append([precision, recall, f1_score])
if step % args.validation_steps == 0: if step % args.validation_steps == 0:
test_process(test_loader, chunk_evaluator) test_process(test_loader, chunk_evaluator)
...@@ -139,7 +144,24 @@ def do_train(args): ...@@ -139,7 +144,24 @@ def do_train(args):
paddle.fluid.save_dygraph(model.state_dict(), save_path) paddle.fluid.save_dygraph(model.state_dict(), save_path)
step += 1 step += 1
if args.enable_ce and fluid.dygraph.parallel.Env().local_rank == 0:
card_num = fluid.core.get_cuda_device_count()
_p = 0
_r = 0
_f1 = 0
_time = 0
try:
_time = ce_time[-1]
_p = ce_infor[-1][0]
_r = ce_infor[-1][1]
_f1 = ce_infor[-1][2]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_p_card%s\t%f" % (card_num, _p))
print("kpis\ttrain_r_card%s\t%f" % (card_num, _r))
print("kpis\ttrain_f1_card%s\t%f" % (card_num, _f1))
if __name__ == "__main__": if __name__ == "__main__":
# 参数控制可以根据需求使用argparse,yaml或者json # 参数控制可以根据需求使用argparse,yaml或者json
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册