提交 0dfd43a2 编写于 作者: 1024的传说's avatar 1024的传说 提交者: Guo Sheng

Fix vae batch time (#4065)

* fix step time in vae

* fix step time in vae
上级 eed924a6
...@@ -203,8 +203,8 @@ def main(): ...@@ -203,8 +203,8 @@ def main():
word_count = 0.0 word_count = 0.0
batch_count = 0.0 batch_count = 0.0
batch_times = [] batch_times = []
batch_start_time = time.time()
for batch_id, batch in enumerate(train_data_iter): for batch_id, batch in enumerate(train_data_iter):
batch_start_time = time.time()
kl_w = min(1.0, kl_w + anneal_r) kl_w = min(1.0, kl_w + anneal_r)
kl_weight = kl_w kl_weight = kl_w
input_data_feed, src_word_num, dec_word_sum = prepare_input( input_data_feed, src_word_num, dec_word_sum = prepare_input(
...@@ -280,6 +280,18 @@ def main(): ...@@ -280,6 +280,18 @@ def main():
print('\nbest testing nll: %.4f, best testing ppl %.4f\n' % print('\nbest testing nll: %.4f, best testing ppl %.4f\n' %
(best_nll, best_ppl)) (best_nll, best_ppl))
if args.enable_ce:
card_num = get_cards()
_ppl = 0
_time = 0
try:
_time = ce_time[-1]
_ppl = ce_ppl[-1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl))
with profile_context(args.profile): with profile_context(args.profile):
train() train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册