未验证 提交 09a04584 编写于 作者: K Kaipeng Deng 提交者: GitHub

refine bbox_normalize in infer.py (#2781)

* refine bbox_normalize in infer.py

* add is_bbox_normalize

* rename _forward to build

* check callable
上级 811680c0
......@@ -57,7 +57,7 @@ class SSD(object):
if isinstance(metric, dict):
self.metric = SSDMetric(**metric)
def _forward(self, feed_vars, mode='train'):
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
if mode == 'train' or mode == 'eval':
gt_box = feed_vars['gt_box']
......@@ -88,10 +88,16 @@ class SSD(object):
return {'bbox': pred}
def train(self, feed_vars):
return self._forward(feed_vars, 'train')
return self.build(feed_vars, 'train')
def eval(self, feed_vars):
return self._forward(feed_vars, 'eval')
return self.build(feed_vars, 'eval')
def test(self, feed_vars):
return self._forward(feed_vars, 'test')
return self.build(feed_vars, 'test')
def is_bbox_normalized(self):
# SSD use output_decoder in output layers, bbox is normalized
# to range [0, 1], is_bbox_normalized is used in infer.py
return True
......@@ -164,6 +164,12 @@ def main():
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
imid2path = reader.imid2path
for iter_id, data in enumerate(reader()):
outs = exe.run(infer_prog,
......@@ -178,7 +184,6 @@ def main():
bbox_results = None
mask_results = None
is_bbox_normalized = True if cfg.metric == 'VOC' else False
if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册