未验证 提交 a6a4a300 编写于 作者: Z zhengya01 提交者: GitHub

add ce for transformer (#4359)

上级 5bc12f64
......@@ -120,6 +120,8 @@ def do_train(args):
args.label_smooth_eps * np.log(args.label_smooth_eps /
(args.trg_vocab_size - 1) + 1e-20))
ce_time = []
ce_ppl = []
step_idx = 0
# train loop
for pass_id in range(args.epoch):
......@@ -165,6 +167,7 @@ def do_train(args):
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
ce_ppl.append(np.exp([min(total_avg_cost, 100)]))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0 and (
......@@ -186,6 +189,7 @@ def do_train(args):
step_idx += 1
time_consumed = time.time() - pass_start_time
ce_time.append(time_consumed)
if args.save_model:
model_dir = os.path.join(args.save_model, "step_final")
......@@ -196,6 +200,17 @@ def do_train(args):
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(model_dir, "transformer"))
if args.enable_ce:
_ppl = 0
_time = 0
try:
_time = ce_time[-1]
_ppl = ce_ppl[-1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (trainer_count, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (trainer_count, _ppl))
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册