diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index d54d2e0859f4577aeacceeb7c152626a6c3a42ed..82a6ab428946a48afb2dbddb9a7a11ad3b79752d 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.ops.operators import RandomRot90 +from .ops.operators import format_data from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid @@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment): super().__init__(*args, **kwargs) self.prob = prob - def __call__(self, ori_data): - img = ori_data["img"] if isinstance(ori_data, dict) else ori_data + @format_data + def __call__(self, img): if not isinstance(img, Image.Image): img = np.ascontiguousarray(img) img = Image.fromarray(img) @@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment): img = super().__call__(img) if isinstance(img, Image.Image): img = np.asarray(img) - processed_data = { - ** - ori_data, - "img": img - } if isinstance(ori_data, dict) else img - return processed_data + + return img diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 454b85d3ffbd5d0d6e868c361114d9c9d5162d2a..d74b26a2c906320e3502e99358da7104cfb77a71 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -34,6 +34,23 @@ from .functional import augmentations from ppcls.utils import logger +def format_data(func): + def warpper(self, data): + if isinstance(data, dict): + img = data["img"] + result = func(self, img) + if not isinstance(result, dict): + result = {"img": result} + return { ** data, ** result} + else: + result = func(self, data) + if isinstance(result, dict): + result = result["img"] + return result + + return warpper + + class UnifiedResize(object): def __init__(self, interpolation=None, backend="cv2", return_numpy=True): _cv2_interp_from_str = { @@ -161,8 +178,8 @@ class DecodeImage(object): f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}." ) - def __call__(self, ori_data): - img = ori_data["img"] if isinstance(ori_data, dict) else ori_data + @format_data + def __call__(self, img): if isinstance(img, Image.Image): assert self.backend == "pil", "invalid input 'img' in DecodeImage" elif isinstance(img, np.ndarray): @@ -189,12 +206,7 @@ class DecodeImage(object): if self.channel_first: img = img.transpose((2, 0, 1)) - processed_data = { - ** - ori_data, - "img": img - } if isinstance(ori_data, dict) else img - return processed_data + return img class ResizeImage(object): @@ -421,8 +433,8 @@ class RandCropImage(object): self._resize_func = UnifiedResize( interpolation=interpolation, backend=backend) - def __call__(self, ori_data): - img = ori_data["img"] if isinstance(ori_data, dict) else ori_data + @format_data + def __call__(self, img): size = self.size scale = self.scale ratio = self.ratio @@ -447,12 +459,7 @@ class RandCropImage(object): j = random.randint(0, img_h - h) img = self._resize_func(img[j:j + h, i:i + w, :], size) - processed_data = { - ** - ori_data, - "img": img - } if isinstance(ori_data, dict) else img - return processed_data + return img class RandCropImageV2(object): @@ -557,8 +564,8 @@ class NormalizeImage(object): self.mean = np.array(mean).reshape(shape).astype('float32') self.std = np.array(std).reshape(shape).astype('float32') - def __call__(self, ori_data): - img = ori_data["img"] if isinstance(ori_data, dict) else ori_data + @format_data + def __call__(self, img): from PIL import Image if isinstance(img, Image.Image): img = np.array(img) @@ -580,12 +587,7 @@ class NormalizeImage(object): (img, pad_zeros), axis=2)) img = img.astype(self.output_dtype) - processed_data = { - ** - ori_data, - "img": img - } if isinstance(ori_data, dict) else img - return processed_data + return img class ToCHWImage(object): @@ -772,15 +774,9 @@ class RandomRot90(object): def __init__(self): pass - def __call__(self, ori_data): - img = ori_data["img"] if isinstance(ori_data, dict) else ori_data + @format_data + def __call__(self, img): orientation = random.choice([0, 1, 2, 3]) if orientation: img = np.rot90(img, orientation) - processed_data = { - ** - ori_data, - "img": img, - "random_rot90_orientation": orientation - } if isinstance(ori_data, dict) else img - return processed_data + return {"img": img, "random_rot90_orientation": orientation} diff --git a/ppcls/data/preprocess/ops/random_erasing.py b/ppcls/data/preprocess/ops/random_erasing.py index 3fa4d1319dfc3bad9d1dc83dcaa1ed3d8bba9863..e687283c7a6695ade91a5601c6892456a388230c 100644 --- a/ppcls/data/preprocess/ops/random_erasing.py +++ b/ppcls/data/preprocess/ops/random_erasing.py @@ -22,6 +22,8 @@ import random import numpy as np +from .operators import format_data + class Pixels(object): def __init__(self, mode="const", mean=[0., 0., 0.]): @@ -70,11 +72,10 @@ class RandomErasing(object): self.attempt = attempt self.get_pixels = Pixels(mode, mean) - def __call__(self, ori_data): + @format_data + def __call__(self, img): if random.random() > self.EPSILON: - return ori_data - - img = ori_data["img"] if isinstance(ori_data, dict) else ori_data + return img for _ in range(self.attempt): if isinstance(img, np.ndarray): @@ -107,16 +108,6 @@ class RandomErasing(object): img[0, x1:x1 + h, y1:y1 + w] = pixels[0] else: img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0] - processed_data = { - ** - ori_data, - "img": img - } if isinstance(ori_data, dict) else img - return processed_data - - processed_data = { - ** - ori_data, - "img": img - } if isinstance(ori_data, dict) else img - return processed_data + return img + + return img