未验证 提交 94a2f9fb 编写于 作者: G Guanghua Yu 提交者: GitHub

fix `with_background` in export model (#633)

* fix with_background in export model

* fix coco_eval background
上级 e3fce291
...@@ -283,8 +283,6 @@ class Config(): ...@@ -283,8 +283,6 @@ class Config():
self.use_python_inference = yml_conf['use_python_inference'] self.use_python_inference = yml_conf['use_python_inference']
self.min_subgraph_size = yml_conf['min_subgraph_size'] self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list'] self.labels = yml_conf['label_list']
if not yml_conf['with_background']:
self.labels = self.labels[1:]
self.mask_resolution = None self.mask_resolution = None
if 'mask_resolution' in yml_conf: if 'mask_resolution' in yml_conf:
self.mask_resolution = yml_conf['mask_resolution'] self.mask_resolution = yml_conf['mask_resolution']
......
...@@ -425,7 +425,9 @@ def get_category_info_from_anno(anno_file, with_background=True): ...@@ -425,7 +425,9 @@ def get_category_info_from_anno(anno_file, with_background=True):
for i, cat in enumerate(cats) for i, cat in enumerate(cats)
} }
catid2name = {cat['id']: cat['name'] for cat in cats} catid2name = {cat['id']: cat['name'] for cat in cats}
if with_background:
clsid2catid.update({0: 0})
catid2name.update({0: 'background'})
return clsid2catid, catid2name return clsid2catid, catid2name
...@@ -607,5 +609,7 @@ def coco17_category_info(with_background=True): ...@@ -607,5 +609,7 @@ def coco17_category_info(with_background=True):
if not with_background: if not with_background:
clsid2catid = {k - 1: v for k, v in clsid2catid.items()} clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
else:
clsid2catid.update({0: 0})
return clsid2catid, catid2name return clsid2catid, catid2name
...@@ -61,6 +61,7 @@ def parse_reader(reader_cfg, metric, arch): ...@@ -61,6 +61,7 @@ def parse_reader(reader_cfg, metric, arch):
metric)) metric))
clsid2catid, catid2name = get_category_info(anno_file, with_background, clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label) use_default_label)
label_list = [str(cat) for cat in catid2name.values()] label_list = [str(cat) for cat in catid2name.values()]
sample_transforms = reader_cfg['sample_transforms'] sample_transforms = reader_cfg['sample_transforms']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册