提交 d89c6b43 编写于 作者: L LDOUBLEV

add profile

上级 3b2e50a1
......@@ -42,6 +42,13 @@ class ArgsParser(ArgumentParser):
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-o", "--opt", nargs='+', help="set configuration options")
self.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
......@@ -151,7 +158,8 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
vdl_writer=None):
vdl_writer=None,
profiler_options=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
log_smooth_window = config['Global']['log_smooth_window']
......@@ -208,6 +216,7 @@ def train(config,
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - batch_start
if idx >= max_iter:
break
......@@ -392,6 +401,7 @@ def eval(model,
def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
......@@ -431,4 +441,4 @@ def preprocess(is_train=False):
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, vdl_writer
return config, device, logger, vdl_writer, profiler_options
......@@ -41,7 +41,7 @@ import tools.program as program
dist.get_world_size()
def main(config, device, logger, vdl_writer):
def main(config, device, logger, vdl_writer, profiler_options):
# init dist environment
if config['Global']['distributed']:
dist.init_parallel_env()
......@@ -105,7 +105,8 @@ def main(config, device, logger, vdl_writer):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
eval_class, pre_best_model_dict, logger, vdl_writer,
profiler_options)
def test_reader(config, device, logger):
......@@ -127,6 +128,8 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
main(config, device, logger, vdl_writer)
config, device, logger, vdl_writer, profiler_options = program.preprocess(
is_train=True)
main(config, device, logger, vdl_writer, profiler_options)
# test_reader(config, device, logger)
# test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册