diff --git a/paddlex/__init__.py b/paddlex/__init__.py index 14e56b432f6d3a86dbfd787f1b171229b3eb74e7..d41fbf8381938b7a4a950671c17ed2fb6dea2329 100644 --- a/paddlex/__init__.py +++ b/paddlex/__init__.py @@ -20,6 +20,12 @@ from . import seg from . import cls from . import slim +try: + import pycocotools +except: + print("[WARNING] pycocotools is not installed, detection model is not available now.") + print("[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md") + env_info = get_environ_info() load_model = cv.models.load_model datasets = cv.datasets diff --git a/paddlex/cv/datasets/coco.py b/paddlex/cv/datasets/coco.py index 52a3336ce60334bd95fa0d84d1adfb164dfb686b..dc9946ac784a296dfc2751ef0a46bc2496aca2e4 100644 --- a/paddlex/cv/datasets/coco.py +++ b/paddlex/cv/datasets/coco.py @@ -19,7 +19,6 @@ import random import numpy as np import paddlex.utils.logging as logging import paddlex as pst -from pycocotools.coco import COCO from .voc import VOCDetection from .dataset import is_pic @@ -47,6 +46,8 @@ class CocoDetection(VOCDetection): buffer_size=100, parallel_method='process', shuffle=False): + from pycocotools.coco import COCO + super(VOCDetection, self).__init__( transforms=transforms, num_workers=num_workers, diff --git a/paddlex/cv/datasets/voc.py b/paddlex/cv/datasets/voc.py index 6ab985fed760001f06d499987baf5d5c6b4dd049..db63d9828d03b95569382f7d059df5567a9ce59d 100644 --- a/paddlex/cv/datasets/voc.py +++ b/paddlex/cv/datasets/voc.py @@ -18,7 +18,6 @@ import os.path as osp import random import numpy as np import xml.etree.ElementTree as ET -from pycocotools.coco import COCO import paddlex.utils.logging as logging from .dataset import Dataset from .dataset import is_pic @@ -51,6 +50,7 @@ class VOCDetection(Dataset): buffer_size=100, parallel_method='process', shuffle=False): + from pycocotools.coco import COCO super(VOCDetection, self).__init__( transforms=transforms, num_workers=num_workers, diff --git a/paddlex/cv/models/slim/visualize.py b/paddlex/cv/models/slim/visualize.py index 5fcbad9865c3a356ef514098fa236112a1cb3169..d9380abb2f1184cfe59d77b84d6841b5c4fd7288 100644 --- a/paddlex/cv/models/slim/visualize.py +++ b/paddlex/cv/models/slim/visualize.py @@ -15,9 +15,6 @@ import os.path as osp import tqdm import numpy as np -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt from .prune import cal_model_size from paddleslim.prune import load_sensitivities @@ -30,6 +27,10 @@ def visualize(model, sensitivities_file, save_dir='./'): model (paddlex.cv.models): paddlex中的模型。 sensitivities_file (str): 敏感度文件存储路径。 """ + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + program = model.test_prog place = model.places[0] fig = plt.figure() diff --git a/paddlex/cv/models/utils/visualize.py b/paddlex/cv/models/utils/visualize.py index 38de5b41aa6658e2ccf7e8c0fec7e8f5a291ecc5..533ea99ff50b097958880e58213c4810425a96f1 100644 --- a/paddlex/cv/models/utils/visualize.py +++ b/paddlex/cv/models/utils/visualize.py @@ -15,7 +15,6 @@ import os import cv2 import numpy as np -import matplotlib.pyplot as plt from PIL import Image, ImageDraw import paddlex.utils.logging as logging @@ -222,6 +221,7 @@ def draw_pr_curve(eval_details_file=None, return mean_s def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'): + import matplotlib.pyplot as plt from pycocotools.cocoeval import COCOeval coco_dt = loadRes(coco_gt, coco_dt) np.linspace = fixed_linspace diff --git a/setup.py b/setup.py index db8c2a8a17420bc56f98d528a406284885d14df9..99b91b86b1442a23515bfc93c326ec1498d67e27 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,8 @@ setuptools.setup( packages=setuptools.find_packages(), setup_requires=['cython', 'numpy', 'sklearn'], install_requires=[ - 'pycocotools', 'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0', + "pycocotools;platform_system!='Windows'", + 'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0', 'paddleslim==1.0.1', 'paddlehub>=1.6.2' ], classifiers=[