提交 eee28b3e 编写于 作者: H hysunflower 提交者: Jinhua Liang

add profiler and max_iter for mask_rcnn (#3930)

上级 fe149e5a
......@@ -40,6 +40,7 @@ import collections
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import reader
import models.model_builder as model_builder
import models.resnet as resnet
......@@ -197,6 +198,14 @@ def train():
sys.stdout.flush()
if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
#profiler tools, used for benchmark
if args.is_profiler and iter_id == 10:
profiler.start_profiler("All")
elif args.is_profiler and iter_id == 15:
profiler.stop_profiler("total", args.profiler_path)
return
end_time = time.time()
total_time = end_time - start_time
last_loss = np.array(outs[0]).mean()
......@@ -232,6 +241,12 @@ def train():
save_model("model_iter{}".format(iter_id))
if (iter_id + 1) == cfg.max_iter:
break
#profiler tools, used for benchmark
if args.is_profiler and iter_id == 10:
profiler.start_profiler("All")
elif args.is_profiler and iter_id == 15:
profiler.stop_profiler("total", args.profiler_path)
return
end_time = time.time()
total_time = end_time - start_time
last_loss = np.array(outs[0]).mean()
......
......@@ -149,6 +149,11 @@ def parse_args():
add_arg('variance', float, [1.,1.,1.,1.], "The variance of anchors.")
add_arg('rpn_stride', float, [16.,16.], "Stride of the feature map that RPN is attached.")
add_arg('rpn_nms_thresh', float, 0.7, "NMS threshold used on RPN proposals")
#NOTE: args for profiler, used for benchmark
add_arg('is_profiler', int, 0, "the profiler switch.(used for benchmark)")
add_arg('profiler_path', str, './', "the profiler output file path.(used for benchmark)")
# TRAIN VAL INFER
add_arg('MASK_ON', bool, False, "Option for different models. If False, choose faster_rcnn. If True, choose mask_rcnn")
add_arg('im_per_batch', int, 1, "Minibatch size.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册