diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py
index 8075ced904de51551c8946905f874e002178abba..e9390d06bf2b4dadc5b900430a9212ea258f7d5f 100644
--- a/ppcls/data/preprocess/ops/operators.py
+++ b/ppcls/data/preprocess/ops/operators.py
@@ -18,6 +18,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from functools import partial
+import io
 import six
 import math
 import random
@@ -93,11 +94,23 @@ class OperatorParamError(ValueError):
 class DecodeImage(object):
     """ decode image """
 
-    def __init__(self, to_rgb=True, to_np=False, channel_first=False):
+    def __init__(self,
+                 to_rgb=True,
+                 to_np=False,
+                 channel_first=False,
+                 backend="cv2"):
         self.to_rgb = to_rgb
         self.to_np = to_np  # to numpy
         self.channel_first = channel_first  # only enabled when to_np is True
 
+        # An unsupported backend falls back to cv2 with a warning, not an error.
+        if backend.lower() not in ["cv2", "pil"]:
+            logger.warning(
+                f"The backend of DecodeImage only support \"cv2\" or \"PIL\". \"{backend}\" is unavailable. Use \"cv2\" instead."
+            )
+            backend = "cv2"
+        self.backend = backend.lower()
+
     def __call__(self, img):
         if six.PY2:
             assert type(img) is str and len(
@@ -105,8 +118,15 @@ class DecodeImage(object):
         else:
             assert type(img) is bytes and len(
                 img) > 0, "invalid input 'img' in DecodeImage"
-        data = np.frombuffer(img, dtype='uint8')
-        img = cv2.imdecode(data, 1)
+
+        if self.backend == "pil":
+            data = io.BytesIO(img)
+            img = Image.open(data).convert("RGB")
+            img = np.asarray(img)[:, :, ::-1]  # to bgr
+        else:
+            data = np.frombuffer(img, dtype='uint8')
+            img = cv2.imdecode(data, 1)
+
         if self.to_rgb:
             assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
                 img.shape)