From 0dfd43a2757e00571dc84ff95e9d18bc62fcf8c3 Mon Sep 17 00:00:00 2001 From: Xing Wu <1160386409@qq.com> Date: Thu, 12 Dec 2019 13:48:50 +0800 Subject: [PATCH] Fix vae batch time (#4065) * fix step time in vae * fix step time in vae --- .../PaddleTextGEN/variational_seq2seq/train.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py index 08b9e1ae..98515a83 100644 --- a/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py @@ -203,8 +203,8 @@ def main(): word_count = 0.0 batch_count = 0.0 batch_times = [] - batch_start_time = time.time() for batch_id, batch in enumerate(train_data_iter): + batch_start_time = time.time() kl_w = min(1.0, kl_w + anneal_r) kl_weight = kl_w input_data_feed, src_word_num, dec_word_sum = prepare_input( @@ -280,6 +280,18 @@ def main(): print('\nbest testing nll: %.4f, best testing ppl %.4f\n' % (best_nll, best_ppl)) + if args.enable_ce: + card_num = get_cards() + _ppl = 0 + _time = 0 + try: + _time = ce_time[-1] + _ppl = ce_ppl[-1] + except: + print("ce info error") + print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time)) + print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl)) + with profile_context(args.profile): train() -- GitLab