提交 c6f6a54a 编写于 作者: H hysunflower 提交者: Jinhua Liang

add profiler for image_classification (#3998)

上级 62e824db
......@@ -178,6 +178,8 @@ def train(args):
compiled_train_prog = best_strategy_compiled(args, train_prog,
train_fetch_vars[0], exe)
#NOTE: this for benchmark
total_batch_num = 0
for pass_id in range(args.num_epochs):
if num_trainers > 1 and not args.use_dali:
imagenet_reader.set_shuffle_seed(pass_id + (
......@@ -192,6 +194,9 @@ def train(args):
t1 = time.time()
for batch in train_iter:
#NOTE: this is for benchmark
if args.max_iter and total_batch_num == args.max_iter:
return
train_batch_metrics = exe.run(compiled_train_prog,
feed=batch,
fetch_list=train_fetch_list)
......@@ -207,6 +212,13 @@ def train(args):
sys.stdout.flush()
train_batch_id += 1
t1 = time.time()
#NOTE: this for benchmark profiler
total_batch_num = total_batch_num + 1
if args.is_profiler and pass_id == 0 and train_batch_id == args.print_step:
profiler.start_profiler("All")
elif args.is_profiler and pass_id == 0 and train_batch_id == args.print_step + 5:
profiler.stop_profiler("total", args.profiler_path)
return
if args.use_dali:
train_iter.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册