提交 0fc149e6 编写于 作者: K Kaipeng Deng 提交者: GitHub

refine bbox_normalize in infer.py (#2781)

* refine bbox_normalize in infer.py

* add is_bbox_normalize

* rename _forward to build

* check callable
上级 a0bd06b8
...@@ -56,8 +56,8 @@ class SSD(object): ...@@ -56,8 +56,8 @@ class SSD(object):
self.output_decoder = SSDOutputDecoder(**output_decoder) self.output_decoder = SSDOutputDecoder(**output_decoder)
if isinstance(metric, dict): if isinstance(metric, dict):
self.metric = SSDMetric(**metric) self.metric = SSDMetric(**metric)
def _forward(self, feed_vars, mode='train'): def build(self, feed_vars, mode='train'):
im = feed_vars['image'] im = feed_vars['image']
if mode == 'train' or mode == 'eval': if mode == 'train' or mode == 'eval':
gt_box = feed_vars['gt_box'] gt_box = feed_vars['gt_box']
...@@ -88,10 +88,16 @@ class SSD(object): ...@@ -88,10 +88,16 @@ class SSD(object):
return {'bbox': pred} return {'bbox': pred}
def train(self, feed_vars): def train(self, feed_vars):
return self._forward(feed_vars, 'train') return self.build(feed_vars, 'train')
def eval(self, feed_vars): def eval(self, feed_vars):
return self._forward(feed_vars, 'eval') return self.build(feed_vars, 'eval')
def test(self, feed_vars): def test(self, feed_vars):
return self._forward(feed_vars, 'test') return self.build(feed_vars, 'test')
def is_bbox_normalized(self):
# SSD use output_decoder in output layers, bbox is normalized
# to range [0, 1], is_bbox_normalized is used in infer.py
return True
...@@ -164,6 +164,12 @@ def main(): ...@@ -164,6 +164,12 @@ def main():
clsid2catid, catid2name = get_category_info(anno_file, with_background, clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label) use_default_label)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
imid2path = reader.imid2path imid2path = reader.imid2path
for iter_id, data in enumerate(reader()): for iter_id, data in enumerate(reader()):
outs = exe.run(infer_prog, outs = exe.run(infer_prog,
...@@ -178,7 +184,6 @@ def main(): ...@@ -178,7 +184,6 @@ def main():
bbox_results = None bbox_results = None
mask_results = None mask_results = None
is_bbox_normalized = True if cfg.metric == 'VOC' else False
if 'bbox' in res: if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res: if 'mask' in res:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册