From 406187e0fe37672cd57cff8033cbbccd95ea677d Mon Sep 17 00:00:00 2001 From: hysunflower <52739577+hysunflower@users.noreply.github.com> Date: Thu, 14 May 2020 21:14:58 +0800 Subject: [PATCH] add_max_iter_for_tarnsformer (#4618) --- dygraph/transformer/train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py index 96fe3bf1..39392c82 100644 --- a/dygraph/transformer/train.py +++ b/dygraph/transformer/train.py @@ -123,11 +123,18 @@ def do_train(args): ce_time = [] ce_ppl = [] step_idx = 0 + + #NOTE: used for benchmark + total_batch_num = 0 + # train loop for pass_id in range(args.epoch): pass_start_time = time.time() batch_id = 0 for input_data in train_loader(): + if args.max_iter and total_batch_num == args.max_iter: #NOTE: used for benchmark + return + batch_start = time.time() (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight) = input_data @@ -186,6 +193,7 @@ def do_train(args): os.path.join(model_dir, "transformer")) batch_id += 1 + total_batch_num = total_batch_num + 1 step_idx += 1 time_consumed = time.time() - pass_start_time -- GitLab