提交 a4bbb799 编写于 作者: H hysunflower 提交者: Jinhua Liang

Add profiler image classification2 (#3917)

* add profiler and max_iter for yolov3 model

* modify image_classification train

* update train.py

* modify args position
上级 e61551b6
......@@ -38,6 +38,7 @@ set_paddle_flags({
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import reader
from utils import *
import models
......@@ -188,6 +189,7 @@ def train(args):
compiled_train_prog = best_strategy_compiled(args, train_prog,
train_fetch_vars[0], exe)
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
total_batch_num = 0 #this is for benchmark
for pass_id in range(args.num_epochs):
if num_trainers > 1:
imagenet_reader.set_shuffle_seed(pass_id + (
......@@ -197,9 +199,10 @@ def train(args):
train_batch_metrics_record = []
train_data_loader.start()
try:
while True:
if args.max_iter and total_batch_num == args.max_iter:
return
t1 = time.time()
train_batch_metrics = exe.run(compiled_train_prog,
fetch_list=train_fetch_list)
......@@ -215,11 +218,19 @@ def train(args):
"batch")
sys.stdout.flush()
train_batch_id += 1
total_batch_num = total_batch_num + 1 #this is for benchmark
##profiler tools
if args.is_profiler and pass_id == 0 and train_batch_id == 100:
profiler.start_profiler("All")
elif args.is_profiler and pass_id == 0 and train_batch_id == 150:
profiler.stop_profiler("total", args.profiler_path)
return
except fluid.core.EOFException:
train_data_loader.reset()
if trainer_id == 0:
if trainer_id == 0 and args.validate:
if args.use_ema:
print('ExponentialMovingAverage validate start...')
with ema.apply(exe):
......
......@@ -142,7 +142,11 @@ def parse_args():
add_arg('padding_type', str, "SAME", "Padding type of convolution")
add_arg('use_se', bool, True, "Whether to use Squeeze-and-Excitation module for EfficientNet.")
# yapf: enable
#NOTE: args for profiler
add_arg('is_profiler', int, 0, "the profiler switch.(used for benchmark)")
add_arg('profiler_path', str, './', "the profiler output file path.(used for benchmark)")
add_arg('max_iter', int, 0, "the max train batch num.(used for benchmark)")
add_arg('validate', int, 1, "whether validate.(used for benchmark)")
args = parser.parse_args()
return args
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册