提交 eed924a6 编写于 作者: H hysunflower 提交者: Jinhua Liang

add_profiler_nextvlad (#4058)

上级 9aefcdc8
......@@ -104,6 +104,17 @@ def parse_args():
type=ast.literal_eval,
default=False,
help='If set True, enable continuous evaluation job.')
# NOTE: args for profiler, used for benchmark
parser.add_argument(
'--profiler_path',
type=str,
default='./',
help='the path to store profiler output file. used for benchmark.')
parser.add_argument(
'--is_profiler',
type=int,
default=0,
help='the switch profiler. used for benchmark.')
args = parser.parse_args()
return args
......@@ -236,7 +247,9 @@ def train(args):
compiled_test_prog=compiled_valid_prog, #test_exe=valid_exe,
test_dataloader=valid_dataloader,
test_fetch_list=valid_fetch_list,
test_metrics=valid_metrics)
test_metrics=valid_metrics,
is_profiler=args.is_profiler,
profiler_path=args.profiler_path)
if __name__ == "__main__":
......
......@@ -18,6 +18,7 @@ import time
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import logging
import shutil
......@@ -76,7 +77,8 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
log_interval = 0, valid_interval = 0, save_dir = './', \
save_model_name = 'model', fix_random_seed = False, \
compiled_test_prog = None, test_dataloader = None, \
test_fetch_list = None, test_metrics = None):
test_fetch_list = None, test_metrics = None, \
is_profiler = None, profiler_path = None):
if not train_dataloader:
logger.error("[TRAIN] get dataloader failed.")
epoch_periods = []
......@@ -98,6 +100,13 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
train_metrics.calculate_and_log_out(train_outs, \
info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter))
train_iter += 1
# NOTE: profiler tools, used for benchmark
if is_profiler and epoch == 0 and train_iter == log_interval:
profiler.start_profiler("All")
elif is_profiler and epoch == 0 and train_iter == log_interval + 5:
profiler.stop_profiler("total", profiler_path)
return
if len(epoch_periods) < 1:
logger.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册