diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py index 08b9e1ae15b848e80d8a671893be6e8babc355f7..98515a8329fba8508b01accbbed940ef2df65842 100644 --- a/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py @@ -203,8 +203,8 @@ def main(): word_count = 0.0 batch_count = 0.0 batch_times = [] - batch_start_time = time.time() for batch_id, batch in enumerate(train_data_iter): + batch_start_time = time.time() kl_w = min(1.0, kl_w + anneal_r) kl_weight = kl_w input_data_feed, src_word_num, dec_word_sum = prepare_input( @@ -280,6 +280,18 @@ def main(): print('\nbest testing nll: %.4f, best testing ppl %.4f\n' % (best_nll, best_ppl)) + if args.enable_ce: + card_num = get_cards() + _ppl = 0 + _time = 0 + try: + _time = ce_time[-1] + _ppl = ce_ppl[-1] + except: + print("ce info error") + print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time)) + print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl)) + with profile_context(args.profile): train()