提交 eed924a6 编写于 作者: H hysunflower 提交者: Jinhua Liang

add_profiler_nextvlad (#4058)

上级 9aefcdc8
...@@ -104,6 +104,17 @@ def parse_args(): ...@@ -104,6 +104,17 @@ def parse_args():
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help='If set True, enable continuous evaluation job.') help='If set True, enable continuous evaluation job.')
# NOTE: args for profiler, used for benchmark
parser.add_argument(
'--profiler_path',
type=str,
default='./',
help='the path to store profiler output file. used for benchmark.')
parser.add_argument(
'--is_profiler',
type=int,
default=0,
help='the switch profiler. used for benchmark.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -236,7 +247,9 @@ def train(args): ...@@ -236,7 +247,9 @@ def train(args):
compiled_test_prog=compiled_valid_prog, #test_exe=valid_exe, compiled_test_prog=compiled_valid_prog, #test_exe=valid_exe,
test_dataloader=valid_dataloader, test_dataloader=valid_dataloader,
test_fetch_list=valid_fetch_list, test_fetch_list=valid_fetch_list,
test_metrics=valid_metrics) test_metrics=valid_metrics,
is_profiler=args.is_profiler,
profiler_path=args.profiler_path)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -18,6 +18,7 @@ import time ...@@ -18,6 +18,7 @@ import time
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import logging import logging
import shutil import shutil
...@@ -76,7 +77,8 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader ...@@ -76,7 +77,8 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
log_interval = 0, valid_interval = 0, save_dir = './', \ log_interval = 0, valid_interval = 0, save_dir = './', \
save_model_name = 'model', fix_random_seed = False, \ save_model_name = 'model', fix_random_seed = False, \
compiled_test_prog = None, test_dataloader = None, \ compiled_test_prog = None, test_dataloader = None, \
test_fetch_list = None, test_metrics = None): test_fetch_list = None, test_metrics = None, \
is_profiler = None, profiler_path = None):
if not train_dataloader: if not train_dataloader:
logger.error("[TRAIN] get dataloader failed.") logger.error("[TRAIN] get dataloader failed.")
epoch_periods = [] epoch_periods = []
...@@ -98,6 +100,13 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader ...@@ -98,6 +100,13 @@ def train_with_dataloader(exe, train_prog, compiled_train_prog, train_dataloader
train_metrics.calculate_and_log_out(train_outs, \ train_metrics.calculate_and_log_out(train_outs, \
info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter)) info = '[TRAIN] Epoch {}, iter {} '.format(epoch, train_iter))
train_iter += 1 train_iter += 1
# NOTE: profiler tools, used for benchmark
if is_profiler and epoch == 0 and train_iter == log_interval:
profiler.start_profiler("All")
elif is_profiler and epoch == 0 and train_iter == log_interval + 5:
profiler.stop_profiler("total", profiler_path)
return
if len(epoch_periods) < 1: if len(epoch_periods) < 1:
logger.info( logger.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册