未验证 提交 00455839 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add the profiler back for static training. (#1094)

上级 274f8190
......@@ -38,7 +38,7 @@ from ppcls.optimizer import build_optimizer
from ppcls.optimizer import build_lr_scheduler
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils import logger, profiler
def create_feeds(image_shape, use_mix=None, dtype="float32"):
......@@ -326,7 +326,8 @@ def run(dataloader,
mode='train',
config=None,
vdl_writer=None,
lr_scheduler=None):
lr_scheduler=None,
profiler_options=None):
"""
Feed data to the model and fetch the measures and loss
......@@ -382,6 +383,8 @@ def run(dataloader,
metric_dict['reader_time'].update(time.time() - tic)
profiler.add_profiler_step(profiler_options)
if use_dali:
batch_size = batch[0]["data"].shape()[0]
feed_dict = batch[0]
......
......@@ -43,6 +43,13 @@ def parse_args():
type=str,
default='configs/ResNet/ResNet50.yaml',
help='config file path')
parser.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
parser.add_argument(
'-o',
'--override',
......@@ -166,7 +173,7 @@ def main(args):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
train_fetchs, epoch_id, 'train', config, vdl_writer,
lr_scheduler)
lr_scheduler, args.profiler_options)
# 2. evaate with eval dataset
if global_config["eval_during_train"] and epoch_id % global_config[
"eval_interval"] == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册