未验证 提交 00455839 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add the profiler back for static training. (#1094)

上级 274f8190
...@@ -38,7 +38,7 @@ from ppcls.optimizer import build_optimizer ...@@ -38,7 +38,7 @@ from ppcls.optimizer import build_optimizer
from ppcls.optimizer import build_lr_scheduler from ppcls.optimizer import build_lr_scheduler
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger from ppcls.utils import logger, profiler
def create_feeds(image_shape, use_mix=None, dtype="float32"): def create_feeds(image_shape, use_mix=None, dtype="float32"):
...@@ -326,7 +326,8 @@ def run(dataloader, ...@@ -326,7 +326,8 @@ def run(dataloader,
mode='train', mode='train',
config=None, config=None,
vdl_writer=None, vdl_writer=None,
lr_scheduler=None): lr_scheduler=None,
profiler_options=None):
""" """
Feed data to the model and fetch the measures and loss Feed data to the model and fetch the measures and loss
...@@ -382,6 +383,8 @@ def run(dataloader, ...@@ -382,6 +383,8 @@ def run(dataloader,
metric_dict['reader_time'].update(time.time() - tic) metric_dict['reader_time'].update(time.time() - tic)
profiler.add_profiler_step(profiler_options)
if use_dali: if use_dali:
batch_size = batch[0]["data"].shape()[0] batch_size = batch[0]["data"].shape()[0]
feed_dict = batch[0] feed_dict = batch[0]
......
...@@ -43,6 +43,13 @@ def parse_args(): ...@@ -43,6 +43,13 @@ def parse_args():
type=str, type=str,
default='configs/ResNet/ResNet50.yaml', default='configs/ResNet/ResNet50.yaml',
help='config file path') help='config file path')
parser.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
parser.add_argument( parser.add_argument(
'-o', '-o',
'--override', '--override',
...@@ -166,7 +173,7 @@ def main(args): ...@@ -166,7 +173,7 @@ def main(args):
# 1. train with train dataset # 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_feeds, program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
train_fetchs, epoch_id, 'train', config, vdl_writer, train_fetchs, epoch_id, 'train', config, vdl_writer,
lr_scheduler) lr_scheduler, args.profiler_options)
# 2. evaate with eval dataset # 2. evaate with eval dataset
if global_config["eval_during_train"] and epoch_id % global_config[ if global_config["eval_during_train"] and epoch_id % global_config[
"eval_interval"] == 0: "eval_interval"] == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册