提交 9f2ab06e 编写于 作者: D dongshuilong

add profiler

上级 22c0a53c
...@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function ...@@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function
import time import time
import paddle import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info from ppcls.engine.train.utils import update_loss, update_metric, log_info
from ppcls.utils import profiler
def train_epoch(trainer, epoch_id, print_batch_step): def train_epoch(trainer, epoch_id, print_batch_step):
...@@ -26,6 +27,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): ...@@ -26,6 +27,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
for iter_id, batch in enumerate(train_dataloader): for iter_id, batch in enumerate(train_dataloader):
if iter_id >= trainer.max_iter: if iter_id >= trainer.max_iter:
break break
profiler.add_profiler_step(trainer.config["profiler_options"])
if iter_id == 5: if iter_id == 5:
for key in trainer.time_info: for key in trainer.time_info:
trainer.time_info[key].reset() trainer.time_info[key].reset()
......
...@@ -199,5 +199,12 @@ def parse_args(): ...@@ -199,5 +199,12 @@ def parse_args():
action='append', action='append',
default=[], default=[],
help='config options to be overridden') help='config options to be overridden')
parser.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -27,5 +27,6 @@ if __name__ == "__main__": ...@@ -27,5 +27,6 @@ if __name__ == "__main__":
args = config.parse_args() args = config.parse_args()
config = config.get_config( config = config.get_config(
args.config, overrides=args.override, show=False) args.config, overrides=args.override, show=False)
config.profiler_options = args.profiler_options
engine = Engine(config, mode="train") engine = Engine(config, mode="train")
engine.train() engine.train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册