提交 db216ad2 编写于 作者: H hysunflower 提交者: Jinhua Liang

add profiler and max_iter for yolov3 model (#3908)

上级 f139a89c
...@@ -41,6 +41,7 @@ from utility import (parse_args, print_arguments, ...@@ -41,6 +41,7 @@ from utility import (parse_args, print_arguments,
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import reader import reader
from models.yolov3 import YOLOv3 from models.yolov3 import YOLOv3
from learning_rate import exponential_with_warmup_decay from learning_rate import exponential_with_warmup_decay
...@@ -186,6 +187,13 @@ def train(): ...@@ -186,6 +187,13 @@ def train():
iter_id, lr[0], iter_id, lr[0],
smoothed_loss.get_mean_value(), start_time - prev_start_time)) smoothed_loss.get_mean_value(), start_time - prev_start_time))
sys.stdout.flush() sys.stdout.flush()
#add profiler tools
if args.is_profiler and iter_id == 5:
profiler.start_profiler("All")
elif args.is_profiler and iter_id == 10:
profiler.stop_profiler("total", args.profiler_path)
return
if (iter_id + 1) % cfg.snapshot_iter == 0: if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id)) save_model("model_iter{}".format(iter_id))
print("Snapshot {} saved, average loss: {}, \ print("Snapshot {} saved, average loss: {}, \
......
...@@ -146,6 +146,9 @@ def parse_args(): ...@@ -146,6 +146,9 @@ def parse_args():
add_arg('draw_thresh', float, 0.5, add_arg('draw_thresh', float, 0.5,
"Confidence score threshold to draw prediction box in image in debug mode") "Confidence score threshold to draw prediction box in image in debug mode")
add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.") add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
# args for profiler tools
add_arg('is_profiler', int, 0, "the switch of profiler")
add_arg('profiler_path', str, './', "the path to save profiler output files")
# yapf: enable # yapf: enable
args = parser.parse_args() args = parser.parse_args()
file_name = sys.argv[0] file_name = sys.argv[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册