Unverified commit 406187e0, authored by H hysunflower, committed by GitHub

add_max_iter_for_tarnsformer (#4618)

Parent f0c08fc0
@@ -123,11 +123,18 @@ def do_train(args):
    ce_time = []
    ce_ppl = []
    step_idx = 0
    #NOTE: used for benchmark
    total_batch_num = 0
    # train loop
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        batch_id = 0
        for input_data in train_loader():
            if args.max_iter and total_batch_num == args.max_iter: #NOTE: used for benchmark
                return
            batch_start = time.time()
            (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
             trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
             lbl_weight) = input_data
@@ -186,6 +193,7 @@ def do_train(args):
                    os.path.join(model_dir, "transformer"))
            batch_id += 1
            total_batch_num = total_batch_num + 1
            step_idx += 1
        time_consumed = time.time() - pass_start_time
......
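
For reference, below is a minimal, self-contained sketch of the pattern this commit introduces: an optional max_iter cap that returns from do_train once a fixed number of batches has been processed, counted across epochs, so benchmark runs always do the same amount of work. The fake_loader helper, the argument defaults, and the timing print are assumptions added for illustration only; they are not part of the repository code.

import argparse
import time


def fake_loader(num_batches=100):
    # Hypothetical stand-in for the real train_loader; yields dummy batches.
    for i in range(num_batches):
        yield i


def do_train(args):
    # NOTE: used for benchmark -- count batches across all epochs.
    total_batch_num = 0
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        batch_id = 0
        for input_data in fake_loader():
            # Stop the whole run once max_iter batches have been processed,
            # so benchmark timings cover a deterministic number of steps.
            if args.max_iter and total_batch_num == args.max_iter:
                return
            batch_start = time.time()
            # ... forward / backward / optimizer step would go here ...
            batch_id += 1
            total_batch_num = total_batch_num + 1
        time_consumed = time.time() - pass_start_time
        print("pass %d finished %d batches in %.2fs"
              % (pass_id, batch_id, time_consumed))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=2)
    # 0 means "no cap" (assumed default), so normal training is unchanged.
    parser.add_argument("--max_iter", type=int, default=0)
    do_train(parser.parse_args())

Because the check is written as "args.max_iter and total_batch_num == args.max_iter", a max_iter of 0 or None short-circuits the condition, leaving ordinary training runs unaffected.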