From 097bad6494090e0a13b871042d6b4982590f40f5 Mon Sep 17 00:00:00 2001 From: hysunflower <52739577+hysunflower@users.noreply.github.com> Date: Wed, 20 Nov 2019 16:46:14 +0800 Subject: [PATCH] update the profiler of seq2seq (#3959) * update the profiler of seq2seq * modify profiler args --- PaddleNLP/PaddleTextGEN/seq2seq/args.py | 7 +------ PaddleNLP/PaddleTextGEN/seq2seq/train.py | 13 ++++++------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/args.py b/PaddleNLP/PaddleTextGEN/seq2seq/args.py index eba46cb3..99f21b08 100644 --- a/PaddleNLP/PaddleTextGEN/seq2seq/args.py +++ b/PaddleNLP/PaddleTextGEN/seq2seq/args.py @@ -123,15 +123,10 @@ def parse_args(): parser.add_argument( "--profile", action='store_true', help="Whether enable the profile.") # NOTE: profiler args, used for benchmark - parser.add_argument( - "--is_profiler", - type=int, - default=0, - help="The switch of profiler tools. (used for benchmark)") parser.add_argument( "--profiler_path", type=str, - default='./', + default='./seq2seq.profile', help="the profiler output file path. (used for benchmark)") args = parser.parse_args() return args diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/train.py b/PaddleNLP/PaddleTextGEN/seq2seq/train.py index cbf6aa4f..e44d9a47 100644 --- a/PaddleNLP/PaddleTextGEN/seq2seq/train.py +++ b/PaddleNLP/PaddleTextGEN/seq2seq/train.py @@ -47,9 +47,9 @@ import pickle @contextlib.contextmanager -def profile_context(profile=True): +def profile_context(profile=True, profiler_path='./seq2seq.profile'): if profile: - with profiler.profiler('All', 'total', 'seq2seq.profile'): + with profiler.profiler('All', 'total', profiler_path): yield else: yield @@ -216,10 +216,9 @@ def main(): word_count = 0.0 # profiler tools - if args.is_profiler and epoch_id == 0 and batch_id == 100: - profiler.start_profiler("All") - elif args.is_profiler and epoch_id == 0 and batch_id == 105: - profiler.stop_profiler("total", args.profiler_path) + if args.profile and epoch_id == 0 and batch_id == 100: + profiler.reset_profiler() + elif args.profile and epoch_id == 0 and batch_id == 105: return end_time = time.time() @@ -252,7 +251,7 @@ def main(): print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time)) print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl)) - with profile_context(args.profile): + with profile_context(args.profile, args.profiler_path): train() -- GitLab