From 9df43751f58f4eda560bdaa1f9d2983a062f98ea Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Wed, 19 Aug 2020 03:44:39 +0000 Subject: [PATCH] add ppyolo in deploy --- paddlex/deploy.py | 12 +++++++----- requirements.txt | 1 + setup.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paddlex/deploy.py b/paddlex/deploy.py index c715af9..ced22ae 100644 --- a/paddlex/deploy.py +++ b/paddlex/deploy.py @@ -19,7 +19,9 @@ import yaml import paddlex import paddle.fluid as fluid from paddlex.cv.transforms import build_transforms -from paddlex.cv.models import BaseClassifier, YOLOv3, FasterRCNN, MaskRCNN, DeepLabv3p +from paddlex.cv.models import BaseClassifier +from paddlex.cv.models import PPYOLO, FasterRCNN, MaskRCNN +from paddlex.cv.models import DeepLabv3p class Predictor: @@ -129,8 +131,8 @@ class Predictor: thread_num=thread_num) res['image'] = im elif self.model_type == "detector": - if self.model_name == "YOLOv3": - im, im_size = YOLOv3._preprocess( + if self.model_name in ["PPYOLO", "YOLOv3"]: + im, im_size = PPYOLO._preprocess( image, self.transforms, self.model_type, @@ -190,8 +192,8 @@ class Predictor: res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), } res['im_id'] = (np.array( [[i] for i in range(batch_size)]).astype('int32'), [[]]) - if self.model_name == "YOLOv3": - preds = YOLOv3._postprocess(res, batch_size, self.num_classes, + if self.model_name in ["PPYOLO", "YOLOv3"]: + preds = PPYOLO._postprocess(res, batch_size, self.num_classes, self.labels) elif self.model_name == "FasterRCNN": preds = FasterRCNN._postprocess(res, batch_size, diff --git a/requirements.txt b/requirements.txt index f7804c2..2e290c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ paddleslim == 1.0.1 shapely x2paddle paddlepaddle-gpu +opencv-python diff --git a/setup.py b/setup.py index 3046353..edcee85 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setuptools.setup( install_requires=[ "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm', 'paddleslim==1.0.1', 'visualdl>=2.0.0b', 'paddlehub>=1.6.2', - 'shapely>=1.7.0' + 'shapely>=1.7.0', "opencv-python" ], classifiers=[ "Programming Language :: Python :: 3", -- GitLab