提交 c22fb053 编写于 作者: J jiangjiajun

remove pycocotools dependency

上级 ece29fe5
...@@ -20,6 +20,12 @@ from . import seg ...@@ -20,6 +20,12 @@ from . import seg
from . import cls from . import cls
from . import slim from . import slim
try:
import pycocotools
except:
print("[WARNING] pycocotools is not installed, detection model is not available now.")
print("[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md")
env_info = get_environ_info() env_info = get_environ_info()
load_model = cv.models.load_model load_model = cv.models.load_model
datasets = cv.datasets datasets = cv.datasets
......
...@@ -19,7 +19,6 @@ import random ...@@ -19,7 +19,6 @@ import random
import numpy as np import numpy as np
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
import paddlex as pst import paddlex as pst
from pycocotools.coco import COCO
from .voc import VOCDetection from .voc import VOCDetection
from .dataset import is_pic from .dataset import is_pic
...@@ -47,6 +46,8 @@ class CocoDetection(VOCDetection): ...@@ -47,6 +46,8 @@ class CocoDetection(VOCDetection):
buffer_size=100, buffer_size=100,
parallel_method='process', parallel_method='process',
shuffle=False): shuffle=False):
from pycocotools.coco import COCO
super(VOCDetection, self).__init__( super(VOCDetection, self).__init__(
transforms=transforms, transforms=transforms,
num_workers=num_workers, num_workers=num_workers,
......
...@@ -18,7 +18,6 @@ import os.path as osp ...@@ -18,7 +18,6 @@ import os.path as osp
import random import random
import numpy as np import numpy as np
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from pycocotools.coco import COCO
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
from .dataset import Dataset from .dataset import Dataset
from .dataset import is_pic from .dataset import is_pic
...@@ -51,6 +50,7 @@ class VOCDetection(Dataset): ...@@ -51,6 +50,7 @@ class VOCDetection(Dataset):
buffer_size=100, buffer_size=100,
parallel_method='process', parallel_method='process',
shuffle=False): shuffle=False):
from pycocotools.coco import COCO
super(VOCDetection, self).__init__( super(VOCDetection, self).__init__(
transforms=transforms, transforms=transforms,
num_workers=num_workers, num_workers=num_workers,
......
...@@ -15,9 +15,6 @@ ...@@ -15,9 +15,6 @@
import os.path as osp import os.path as osp
import tqdm import tqdm
import numpy as np import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from .prune import cal_model_size from .prune import cal_model_size
from paddleslim.prune import load_sensitivities from paddleslim.prune import load_sensitivities
...@@ -30,6 +27,10 @@ def visualize(model, sensitivities_file, save_dir='./'): ...@@ -30,6 +27,10 @@ def visualize(model, sensitivities_file, save_dir='./'):
model (paddlex.cv.models): paddlex中的模型。 model (paddlex.cv.models): paddlex中的模型。
sensitivities_file (str): 敏感度文件存储路径。 sensitivities_file (str): 敏感度文件存储路径。
""" """
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
program = model.test_prog program = model.test_prog
place = model.places[0] place = model.places[0]
fig = plt.figure() fig = plt.figure()
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import os import os
import cv2 import cv2
import numpy as np import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
...@@ -222,6 +221,7 @@ def draw_pr_curve(eval_details_file=None, ...@@ -222,6 +221,7 @@ def draw_pr_curve(eval_details_file=None,
return mean_s return mean_s
def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'): def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
import matplotlib.pyplot as plt
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
coco_dt = loadRes(coco_gt, coco_dt) coco_dt = loadRes(coco_gt, coco_dt)
np.linspace = fixed_linspace np.linspace = fixed_linspace
......
...@@ -29,7 +29,8 @@ setuptools.setup( ...@@ -29,7 +29,8 @@ setuptools.setup(
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
setup_requires=['cython', 'numpy', 'sklearn'], setup_requires=['cython', 'numpy', 'sklearn'],
install_requires=[ install_requires=[
'pycocotools', 'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0', "pycocotools;platform_system!='Windows'",
'pyyaml', 'colorama', 'tqdm', 'visualdl==1.3.0',
'paddleslim==1.0.1', 'paddlehub>=1.6.2' 'paddleslim==1.0.1', 'paddlehub>=1.6.2'
], ],
classifiers=[ classifiers=[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册