提交 33b6550b 编写于 作者: W wangguanzhong 提交者: GitHub

fix eval for multi-scale test (#3592)

上级 053b1d37
...@@ -93,6 +93,7 @@ class MaskRCNN(object): ...@@ -93,6 +93,7 @@ class MaskRCNN(object):
for k, v in body_feats.items()) for k, v in body_feats.items())
# FPN # FPN
spatial_scale = None
if self.fpn is not None: if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats) body_feats, spatial_scale = self.fpn.get_output(body_feats)
......
...@@ -81,6 +81,9 @@ def main(): ...@@ -81,6 +81,9 @@ def main():
with fluid.program_guard(eval_prog, startup_prog): with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
pyreader, feed_vars = create_feed(eval_feed) pyreader, feed_vars = create_feed(eval_feed)
if multi_scale_test is None:
fetches = model.eval(feed_vars)
else:
fetches = model.eval(feed_vars, multi_scale_test) fetches = model.eval(feed_vars, multi_scale_test)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
......
...@@ -77,9 +77,6 @@ def main(): ...@@ -77,9 +77,6 @@ def main():
if 'log_iter' not in cfg: if 'log_iter' not in cfg:
cfg.log_iter = 20 cfg.log_iter = 20
if 'multi_scale_test' not in cfg:
cfg.multi_scale_test = False
ignore_params = cfg.finetune_exclude_pretrained_params \ ignore_params = cfg.finetune_exclude_pretrained_params \
if 'finetune_exclude_pretrained_params' in cfg else [] if 'finetune_exclude_pretrained_params' in cfg else []
...@@ -150,7 +147,7 @@ def main(): ...@@ -150,7 +147,7 @@ def main():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch) model = create(main_arch)
eval_pyreader, feed_vars = create_feed(eval_feed) eval_pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars, cfg.multi_scale_test) fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册