diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index b706dbd9a0ef87c29131b5c6f2dcb7c0a89d45d4..e617b8a71afffeb9e18e4be412f5a3374bd387ec 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -140,13 +140,12 @@ class DecodeImage(object): """ decode image """ def __init__(self, + to_np=True, to_rgb=True, - to_np=False, channel_first=False, - backend="cv2", - return_numpy=True): - self.to_rgb = to_rgb + backend="cv2"): self.to_np = to_np # to numpy + self.to_rgb = to_rgb # only enabled when to_np is True self.channel_first = channel_first # only enabled when to_np is True if backend.lower() not in ["cv2", "pil"]: @@ -156,38 +155,33 @@ class DecodeImage(object): backend = "cv2" self.backend = backend.lower() - if not return_numpy: - assert to_rgb, f"\"to_rgb\" must be True while \"return_numpy\" is False." - assert not channel_first, f"\"channel_first\" must be False while \"return_numpy\" is False." - self.return_numpy = return_numpy + if not to_np: + logger.warning( + f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}." + ) def __call__(self, img): if isinstance(img, Image.Image): - if self.return_numpy: - img = np.asarray(img)[:, :, ::-1] # to bgr + assert self.backend == "pil", "invalid input 'img' in DecodeImage" elif isinstance(img, np.ndarray): - assert self.return_numpy, "invalid input 'img' in DecodeImage" - else: - if six.PY2: - assert type(img) is str and len( - img) > 0, "invalid input 'img' in DecodeImage" - else: - assert type(img) is bytes and len( - img) > 0, "invalid input 'img' in DecodeImage" - + assert self.backend == "cv2", "invalid input 'img' in DecodeImage" + elif isinstance(img, bytes): if self.backend == "pil": data = io.BytesIO(img) - img = Image.open(data).convert("RGB") - if self.return_numpy: - img = np.asarray(img)[:, :, ::-1] # to bgr + img = Image.open(data) else: - data = np.frombuffer(img, dtype='uint8') + data = np.frombuffer(img, dtype="uint8") img = cv2.imdecode(data, 1) + else: + raise ValueError("invalid input 'img' in DecodeImage") + + if self.to_np: + if self.backend == "pil": + assert img.mode == "RGB", f"invalid shape of image[{img.shape}]" + img = np.asarray(img)[:, :, ::-1] # BRG - if self.return_numpy: if self.to_rgb: - assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( - img.shape) + assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]" img = img[:, :, ::-1] if self.channel_first: