未验证 提交 406187e0 编写于 作者: H hysunflower 提交者: GitHub

add_max_iter_for_tarnsformer (#4618)

上级 f0c08fc0
...@@ -123,11 +123,18 @@ def do_train(args): ...@@ -123,11 +123,18 @@ def do_train(args):
ce_time = [] ce_time = []
ce_ppl = [] ce_ppl = []
step_idx = 0 step_idx = 0
#NOTE: used for benchmark
total_batch_num = 0
# train loop # train loop
for pass_id in range(args.epoch): for pass_id in range(args.epoch):
pass_start_time = time.time() pass_start_time = time.time()
batch_id = 0 batch_id = 0
for input_data in train_loader(): for input_data in train_loader():
if args.max_iter and total_batch_num == args.max_iter: #NOTE: used for benchmark
return
batch_start = time.time()
(src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
lbl_weight) = input_data lbl_weight) = input_data
...@@ -186,6 +193,7 @@ def do_train(args): ...@@ -186,6 +193,7 @@ def do_train(args):
os.path.join(model_dir, "transformer")) os.path.join(model_dir, "transformer"))
batch_id += 1 batch_id += 1
total_batch_num = total_batch_num + 1
step_idx += 1 step_idx += 1
time_consumed = time.time() - pass_start_time time_consumed = time.time() - pass_start_time
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册