未验证 提交 3c1086f2 编写于 作者: W wangguanzhong 提交者: GitHub

Refine import cocoapi (#143)

* refine cocoapi import

* fix infer when anno_path is None
上级 f9d57fcf
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import os import os
import numpy as np import numpy as np
from pycocotools.coco import COCO
from .dataset import DataSet from .dataset import DataSet
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
...@@ -75,6 +74,7 @@ class COCODataSet(DataSet): ...@@ -75,6 +74,7 @@ class COCODataSet(DataSet):
assert anno_path.endswith('.json'), \ assert anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path 'invalid coco annotation file: ' + anno_path
from pycocotools.coco import COCO
coco = COCO(anno_path) coco = COCO(anno_path)
img_ids = coco.getImgIds() img_ids = coco.getImgIds()
cat_ids = coco.getCatIds() cat_ids = coco.getCatIds()
......
...@@ -75,6 +75,8 @@ class DataSet(object): ...@@ -75,6 +75,8 @@ class DataSet(object):
return self.cname2cid return self.cname2cid
def get_anno(self): def get_anno(self):
if self.anno_path is None:
return
return os.path.join(self.dataset_dir, self.anno_path) return os.path.join(self.dataset_dir, self.anno_path)
def get_imid2path(self): def get_imid2path(self):
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy as np import numpy as np
import pycocotools.mask as mask_util
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from .colormap import colormap from .colormap import colormap
...@@ -56,6 +55,7 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7): ...@@ -56,6 +55,7 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7):
segm, score = dt['segmentation'], dt['score'] segm, score = dt['segmentation'], dt['score']
if score < threshold: if score < threshold:
continue continue
import pycocotools.mask as mask_util
mask = mask_util.decode(segm) * 255 mask = mask_util.decode(segm) * 255
color_mask = color_list[mask_color_id % len(color_list), 0:3] color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1 mask_color_id += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册