From 11ed966b14398d01b68e62f3c8cabd9d8b842fe6 Mon Sep 17 00:00:00 2001 From: hysunflower <52739577+hysunflower@users.noreply.github.com> Date: Tue, 26 Nov 2019 16:54:41 +0800 Subject: [PATCH] add_for_transformer_model (#3980) --- PaddleNLP/PaddleMT/transformer/train.py | 12 ++++++++++++ PaddleNLP/PaddleMT/transformer/utils/configure.py | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/PaddleNLP/PaddleMT/transformer/train.py b/PaddleNLP/PaddleMT/transformer/train.py index 48b4847f..b3555be6 100644 --- a/PaddleNLP/PaddleMT/transformer/train.py +++ b/PaddleNLP/PaddleMT/transformer/train.py @@ -21,6 +21,7 @@ import time import numpy as np import paddle import paddle.fluid as fluid +from paddle.fluid import profiler import utils.dist_utils as dist_utils from utils.input_field import InputField @@ -250,12 +251,15 @@ def do_train(args): # start training step_idx = 0 + total_batch_num = 0 # this is for benchmark for pass_id in range(args.epoch): pass_start_time = time.time() input_field.loader.start() batch_id = 0 while True: + if args.max_iter and total_batch_num == args.max_iter: # this is for benchmark + return try: outs = exe.run(compiled_train_prog, fetch_list=[sum_cost.name, token_num.name] @@ -299,6 +303,14 @@ def do_train(args): batch_id += 1 step_idx += 1 + total_batch_num = total_batch_num + 1 # this is for benchmark + + # profiler tools for benchmark + if args.is_profiler and pass_id == 0 and batch_id == args.print_step: + profiler.start_profiler("All") + elif args.is_profiler and pass_id == 0 and batch_id == args.print_step + 5: + profiler.stop_profiler("total", args.profiler_path) + return except fluid.core.EOFException: input_field.loader.reset() diff --git a/PaddleNLP/PaddleMT/transformer/utils/configure.py b/PaddleNLP/PaddleMT/transformer/utils/configure.py index 2ea9fd96..67e60128 100644 --- a/PaddleNLP/PaddleMT/transformer/utils/configure.py +++ b/PaddleNLP/PaddleMT/transformer/utils/configure.py @@ -198,6 +198,11 @@ class PDConfig(object): 
self.default_g.add_arg("do_save_inference_model", bool, False, "Whether to perform model saving for inference.") + # NOTE: args for profiler + self.default_g.add_arg("is_profiler", int, 0, "the switch of profiler tools. (used for benchmark)") + self.default_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)") + self.default_g.add_arg("max_iter", int, 0, "the max train batch num.(used for benchmark)") + self.parser = parser if json_file != "": -- GitLab