From 3c1086f21c732b775a86b9865b1478bda0e4ca9c Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Mon, 30 Dec 2019 19:53:39 +0800 Subject: [PATCH] Refine import cocoapi (#143) * refine cocoapi import * fix infer when anno_path is None --- ppdet/data/source/coco.py | 2 +- ppdet/data/source/dataset.py | 2 ++ ppdet/utils/visualizer.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py index 22a0d7e3a..0833a3227 100644 --- a/ppdet/data/source/coco.py +++ b/ppdet/data/source/coco.py @@ -14,7 +14,6 @@ import os import numpy as np -from pycocotools.coco import COCO from .dataset import DataSet from ppdet.core.workspace import register, serializable @@ -75,6 +74,7 @@ class COCODataSet(DataSet): assert anno_path.endswith('.json'), \ 'invalid coco annotation file: ' + anno_path + from pycocotools.coco import COCO coco = COCO(anno_path) img_ids = coco.getImgIds() cat_ids = coco.getCatIds() diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index 9f604fe9a..806dd2b12 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -75,6 +75,8 @@ class DataSet(object): return self.cname2cid def get_anno(self): + if self.anno_path is None: + return return os.path.join(self.dataset_dir, self.anno_path) def get_imid2path(self): diff --git a/ppdet/utils/visualizer.py b/ppdet/utils/visualizer.py index e27bad90b..0658c8c35 100644 --- a/ppdet/utils/visualizer.py +++ b/ppdet/utils/visualizer.py @@ -18,7 +18,6 @@ from __future__ import print_function from __future__ import unicode_literals import numpy as np -import pycocotools.mask as mask_util from PIL import Image, ImageDraw from .colormap import colormap @@ -56,6 +55,7 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7): segm, score = dt['segmentation'], dt['score'] if score < threshold: continue + import pycocotools.mask as mask_util mask = mask_util.decode(segm) * 255 color_mask = color_list[mask_color_id % len(color_list), 0:3] mask_color_id += 1 -- GitLab