提交 eee28b3e 编写于 作者: H hysunflower 提交者: Jinhua Liang

add profiler and max_iter for mask_rcnn (#3930)

上级 fe149e5a
...@@ -40,6 +40,7 @@ import collections ...@@ -40,6 +40,7 @@ import collections
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import reader import reader
import models.model_builder as model_builder import models.model_builder as model_builder
import models.resnet as resnet import models.resnet as resnet
...@@ -197,6 +198,14 @@ def train(): ...@@ -197,6 +198,14 @@ def train():
sys.stdout.flush() sys.stdout.flush()
if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0: if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id)) save_model("model_iter{}".format(iter_id))
#profiler tools, used for benchmark
if args.is_profiler and iter_id == 10:
profiler.start_profiler("All")
elif args.is_profiler and iter_id == 15:
profiler.stop_profiler("total", args.profiler_path)
return
end_time = time.time() end_time = time.time()
total_time = end_time - start_time total_time = end_time - start_time
last_loss = np.array(outs[0]).mean() last_loss = np.array(outs[0]).mean()
...@@ -232,6 +241,12 @@ def train(): ...@@ -232,6 +241,12 @@ def train():
save_model("model_iter{}".format(iter_id)) save_model("model_iter{}".format(iter_id))
if (iter_id + 1) == cfg.max_iter: if (iter_id + 1) == cfg.max_iter:
break break
#profiler tools, used for benchmark
if args.is_profiler and iter_id == 10:
profiler.start_profiler("All")
elif args.is_profiler and iter_id == 15:
profiler.stop_profiler("total", args.profiler_path)
return
end_time = time.time() end_time = time.time()
total_time = end_time - start_time total_time = end_time - start_time
last_loss = np.array(outs[0]).mean() last_loss = np.array(outs[0]).mean()
......
...@@ -149,6 +149,11 @@ def parse_args(): ...@@ -149,6 +149,11 @@ def parse_args():
add_arg('variance', float, [1.,1.,1.,1.], "The variance of anchors.") add_arg('variance', float, [1.,1.,1.,1.], "The variance of anchors.")
add_arg('rpn_stride', float, [16.,16.], "Stride of the feature map that RPN is attached.") add_arg('rpn_stride', float, [16.,16.], "Stride of the feature map that RPN is attached.")
add_arg('rpn_nms_thresh', float, 0.7, "NMS threshold used on RPN proposals") add_arg('rpn_nms_thresh', float, 0.7, "NMS threshold used on RPN proposals")
#NOTE: args for profiler, used for benchmark
add_arg('is_profiler', int, 0, "the profiler switch.(used for benchmark)")
add_arg('profiler_path', str, './', "the profiler output file path.(used for benchmark)")
# TRAIN VAL INFER # TRAIN VAL INFER
add_arg('MASK_ON', bool, False, "Option for different models. If False, choose faster_rcnn. If True, choose mask_rcnn") add_arg('MASK_ON', bool, False, "Option for different models. If False, choose faster_rcnn. If True, choose mask_rcnn")
add_arg('im_per_batch', int, 1, "Minibatch size.") add_arg('im_per_batch', int, 1, "Minibatch size.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册