diff --git a/PaddleNLP/language_model/args.py b/PaddleNLP/language_model/args.py
index 8014bb521ec4f276603c45702d4090048235646c..964686890532314d4b179bc822419ed2adbdb2ed 100644
--- a/PaddleNLP/language_model/args.py
+++ b/PaddleNLP/language_model/args.py
@@ -80,5 +80,8 @@ def parse_args():
     parser.add_argument('--enable_ce', action='store_true')
     parser.add_argument('--batch_size', type=int, default=0, help='batch size')
     parser.add_argument('--max_epoch', type=int, default=0, help='max epoch')
+
+    # NOTE: args for profiler, used for benchmark
+    parser.add_argument('--profiler_path', type=str, default='/tmp/paddingrnn.profile', help='the profiler output file path. used for benchmark')
     args = parser.parse_args()
     return args
diff --git a/PaddleNLP/language_model/train.py b/PaddleNLP/language_model/train.py
index f26e437dbb60d3f56dd1cfbb9e892a89d417a9f6..10562712f807672e2016c85419517cefac0ffc50 100644
--- a/PaddleNLP/language_model/train.py
+++ b/PaddleNLP/language_model/train.py
@@ -25,6 +25,7 @@ import contextlib
 from distutils.dir_util import mkpath
 import paddle
 import paddle.fluid as fluid
+from paddle.fluid import profiler
 import paddle.fluid.framework as framework
 import paddle.fluid.profiler as profiler
 from paddle.fluid.executor import Executor
@@ -50,9 +51,9 @@ SEED = 123


 @contextlib.contextmanager
-def profile_context(profile=True):
+def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
     if profile:
-        with profiler.profiler('All', 'total', '/tmp/paddingrnn.profile'):
+        with profiler.profiler('All', 'total', profiler_path):
             yield
     else:
         yield
@@ -318,6 +319,12 @@ def main():
                 print(
                     "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
                     % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
+
+            # profiler tools for benchmark
+            if args.profile and batch_id == log_interval:
+                profiler.reset_profiler()
+            elif args.profile and batch_id == (log_interval + 5):
+                break

         ppl = np.exp(total_loss / iters)
         return ppl
@@ -371,6 +378,11 @@ def main():
                         % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))

                 batch_id += 1
+                # profiler tools for benchmark
+                if args.profile and batch_id == log_interval:
+                    profiler.reset_profiler()
+                elif args.profile and batch_id == (log_interval + 5):
+                    break
         except fluid.core.EOFException:
             dataloader.reset()

@@ -455,7 +467,7 @@ def main():
             fluid.save(main_program, save_model_dir)
             print("Saved model to: %s.\n" % save_model_dir)

-    with profile_context(args.profile):
+    with profile_context(args.profile, args.profiler_path):
         train()

     test_ppl = eval(test_data)
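
The training hunks above follow a common profiling pattern: treat the first log_interval batches as warm-up, call profiler.reset_profiler() to discard their statistics, then stop after five more profiled batches so the report covers only steady-state iterations. A minimal standalone sketch of that pattern, assuming the same paddle.fluid.profiler API the diff uses (the reader and per-batch step below are hypothetical placeholders, not code from this repo):

    import paddle.fluid.profiler as profiler

    log_interval = 100  # warm-up batches to discard; mirrors the script's log interval

    with profiler.profiler('All', 'total', '/tmp/paddingrnn.profile'):
        for batch_id, batch in enumerate(train_reader()):   # hypothetical data reader
            run_one_batch(batch)                             # hypothetical training step
            if batch_id == log_interval:
                # drop warm-up statistics so only steady-state batches are reported
                profiler.reset_profiler()
            elif batch_id == log_interval + 5:
                # keep the profiled window short; the report is written when the
                # profiler context exits
                break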