diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/args.py b/PaddleNLP/PaddleTextGEN/seq2seq/args.py index ee056e33597651f9e166e4d6399c89bfc36598f7..eba46cb3beed3b63ec73a9c58554a05993b6a08c 100644 --- a/PaddleNLP/PaddleTextGEN/seq2seq/args.py +++ b/PaddleNLP/PaddleTextGEN/seq2seq/args.py @@ -122,6 +122,16 @@ def parse_args(): parser.add_argument( "--profile", action='store_true', help="Whether enable the profile.") - + # NOTE: profiler args, used for benchmark + parser.add_argument( + "--is_profiler", + type=int, + default=0, + help="The switch of profiler tools. (used for benchmark)") + parser.add_argument( + "--profiler_path", + type=str, + default='./', + help="the profiler output file path. (used for benchmark)") args = parser.parse_args() return args diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/train.py b/PaddleNLP/PaddleTextGEN/seq2seq/train.py index 51d4d29eac141f35676fe92ef2713c63f86d4aae..cbf6aa4f39af5322bc58dba7a60251eca83d02ff 100644 --- a/PaddleNLP/PaddleTextGEN/seq2seq/train.py +++ b/PaddleNLP/PaddleTextGEN/seq2seq/train.py @@ -27,6 +27,7 @@ import contextlib import paddle import paddle.fluid as fluid +from paddle.fluid import profiler import paddle.fluid.framework as framework import paddle.fluid.profiler as profiler from paddle.fluid.executor import Executor @@ -213,6 +214,13 @@ def main(): ce_ppl.append(np.exp(total_loss / word_count)) total_loss = 0.0 word_count = 0.0 + + # profiler tools + if args.is_profiler and epoch_id == 0 and batch_id == 100: + profiler.start_profiler("All") + elif args.is_profiler and epoch_id == 0 and batch_id == 105: + profiler.stop_profiler("total", args.profiler_path) + return end_time = time.time() epoch_time = end_time - start_time