diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py index 96fe3bf1777f681de4625095dff7970ccfb7a590..39392c82c5ae58ab45660d40d83f9b37e78d60e0 100644 --- a/dygraph/transformer/train.py +++ b/dygraph/transformer/train.py @@ -123,11 +123,18 @@ def do_train(args): ce_time = [] ce_ppl = [] step_idx = 0 + + #NOTE: used for benchmark + total_batch_num = 0 + # train loop for pass_id in range(args.epoch): pass_start_time = time.time() batch_id = 0 for input_data in train_loader(): + if args.max_iter and total_batch_num == args.max_iter: #NOTE: used for benchmark + return + batch_start = time.time() (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight) = input_data @@ -186,6 +193,7 @@ def do_train(args): os.path.join(model_dir, "transformer")) batch_id += 1 + total_batch_num = total_batch_num + 1 step_idx += 1 time_consumed = time.time() - pass_start_time