diff --git a/ppcls/static/program.py b/ppcls/static/program.py index 71f630f7b7913988101510c1766e82da36a29932..e6022bbde4529b353db6102e5ac93f798a1cd196 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -38,7 +38,7 @@ from ppcls.optimizer import build_optimizer from ppcls.optimizer import build_lr_scheduler from ppcls.utils.misc import AverageMeter -from ppcls.utils import logger +from ppcls.utils import logger, profiler def create_feeds(image_shape, use_mix=None, dtype="float32"): @@ -326,7 +326,8 @@ def run(dataloader, mode='train', config=None, vdl_writer=None, - lr_scheduler=None): + lr_scheduler=None, + profiler_options=None): """ Feed data to the model and fetch the measures and loss @@ -382,6 +383,8 @@ def run(dataloader, metric_dict['reader_time'].update(time.time() - tic) + profiler.add_profiler_step(profiler_options) + if use_dali: batch_size = batch[0]["data"].shape()[0] feed_dict = batch[0] diff --git a/ppcls/static/train.py b/ppcls/static/train.py index d894ce8ca2eb4853db30441c4304964751e91e71..a3aa0b591ce2db7d1066f1fada521e3a91cfd239 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -43,6 +43,13 @@ def parse_args(): type=str, default='configs/ResNet/ResNet50.yaml', help='config file path') + parser.add_argument( + '-p', + '--profiler_options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) parser.add_argument( '-o', '--override', @@ -166,7 +173,7 @@ def main(args): # 1. train with train dataset program.run(train_dataloader, exe, compiled_train_prog, train_feeds, train_fetchs, epoch_id, 'train', config, vdl_writer, - lr_scheduler) + lr_scheduler, args.profiler_options) # 2. evaate with eval dataset if global_config["eval_during_train"] and epoch_id % global_config[ "eval_interval"] == 0: