From eed924a67189ac9fa4f752a29fc9d0228c36b62d Mon Sep 17 00:00:00 2001 From: hysunflower <52739577+hysunflower@users.noreply.github.com> Date: Wed, 11 Dec 2019 19:39:55 +0800 Subject: [PATCH] add_profiler_nextvlad (#4058) --- PaddleCV/PaddleVideo/train.py | 15 ++++++++++++++- PaddleCV/PaddleVideo/utils/train_utils.py | 11 ++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/PaddleCV/PaddleVideo/train.py b/PaddleCV/PaddleVideo/train.py index 467523d8..4adac343 100644 --- a/PaddleCV/PaddleVideo/train.py +++ b/PaddleCV/PaddleVideo/train.py @@ -104,6 +104,17 @@ def parse_args(): type=ast.literal_eval, default=False, help='If set True, enable continuous evaluation job.') + # NOTE: args for profiler, used for benchmark + parser.add_argument( + '--profiler_path', + type=str, + default='./', + help='the path to store profiler output file. used for benchmark.') + parser.add_argument( + '--is_profiler', + type=int, + default=0, + help='the switch profiler. used for benchmark.') args = parser.parse_args() return args @@ -236,7 +247,9 @@ def train(args): compiled_test_prog=compiled_valid_prog, #test_exe=valid_exe, test_dataloader=valid_dataloader, test_fetch_list=valid_fetch_list, - test_metrics=valid_metrics) + test_metrics=valid_metrics, + is_profiler=args.is_profiler, + profiler_path=args.profiler_path) if __name__ == "__main__": diff --git a/PaddleCV/PaddleVideo/utils/train_utils.py b/PaddleCV/PaddleVideo/utils/train_utils.py index 4168abbb..f7e48918 100644 --- a/PaddleCV/PaddleVideo/utils/train_utils.py +++ b/PaddleCV/PaddleVideo/utils/train_utils.py @@ -18,6 +18,7 @@ import time import numpy as np import paddle import paddle.fluid as fluid +from paddle.fluid import profiler import logging import shutil @@ -76,7 +77,8 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader log_interval = 0, valid_interval = 0, save_dir = './', \ save_model_name = 'model', fix_random_seed = False, \ compiled_test_prog = None, test_dataloader = None, \ - test_fetch_list = None, test_metrics = None): + test_fetch_list = None, test_metrics = None, \ + is_profiler = None, profiler_path = None): if not train_dataloader: logger.error("[TRAIN] get dataloader failed.") epoch_periods = [] @@ -98,6 +100,13 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader train_metrics.calculate_and_log_out(train_outs, \ info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter)) train_iter += 1 + + # NOTE: profiler tools, used for benchmark + if is_profiler and epoch == 0 and train_iter == log_interval: + profiler.start_profiler("All") + elif is_profiler and epoch == 0 and train_iter == log_interval + 5: + profiler.stop_profiler("total", profiler_path) + return if len(epoch_periods) < 1: logger.info( -- GitLab