diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py index 22a0d7e3a5b22ec9b5df4e3dc18ebbab882be8d1..0833a322797b19554fa6f98da62ea00ac1d49b96 100644 --- a/ppdet/data/source/coco.py +++ b/ppdet/data/source/coco.py @@ -14,7 +14,6 @@ import os import numpy as np -from pycocotools.coco import COCO from .dataset import DataSet from ppdet.core.workspace import register, serializable @@ -75,6 +74,7 @@ class COCODataSet(DataSet): assert anno_path.endswith('.json'), \ 'invalid coco annotation file: ' + anno_path + from pycocotools.coco import COCO coco = COCO(anno_path) img_ids = coco.getImgIds() cat_ids = coco.getCatIds() diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index 9f604fe9ab20c60c3634b25895ac5564b633b6b2..806dd2b12423d6bddd09ba9e50566bb33c1a3bd1 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -75,6 +75,8 @@ class DataSet(object): return self.cname2cid def get_anno(self): + if self.anno_path is None: + return return os.path.join(self.dataset_dir, self.anno_path) def get_imid2path(self): diff --git a/ppdet/utils/visualizer.py b/ppdet/utils/visualizer.py index e27bad90b9dbc250bc0962d623565387bb1aa97d..0658c8c355db67aa54c4e461b5eeb40506d668bc 100644 --- a/ppdet/utils/visualizer.py +++ b/ppdet/utils/visualizer.py @@ -18,7 +18,6 @@ from __future__ import print_function from __future__ import unicode_literals import numpy as np -import pycocotools.mask as mask_util from PIL import Image, ImageDraw from .colormap import colormap @@ -56,6 +55,7 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7): segm, score = dt['segmentation'], dt['score'] if score < threshold: continue + import pycocotools.mask as mask_util mask = mask_util.decode(segm) * 255 color_mask = color_list[mask_color_id % len(color_list), 0:3] mask_color_id += 1