diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 9cdc58b2b5f6c4987c898a6549f5652797d1c845..8075ced904de51551c8946905f874e002178abba 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -38,7 +38,8 @@ class UnifiedResize(object): 'bilinear': cv2.INTER_LINEAR, 'area': cv2.INTER_AREA, 'bicubic': cv2.INTER_CUBIC, - 'lanczos': cv2.INTER_LANCZOS4 + 'lanczos': cv2.INTER_LANCZOS4, + 'random': (cv2.INTER_LINEAR, cv2.INTER_CUBIC) } _pil_interp_from_str = { 'nearest': Image.NEAREST, @@ -46,10 +47,18 @@ class UnifiedResize(object): 'bicubic': Image.BICUBIC, 'box': Image.BOX, 'lanczos': Image.LANCZOS, - 'hamming': Image.HAMMING + 'hamming': Image.HAMMING, + 'random': (Image.BILINEAR, Image.BICUBIC) } + def _cv2_resize(src, size, resample): + if isinstance(resample, tuple): + resample = random.choice(resample) + return cv2.resize(src, size, interpolation=resample) + def _pil_resize(src, size, resample): + if isinstance(resample, tuple): + resample = random.choice(resample) pil_img = Image.fromarray(src) pil_img = pil_img.resize(size, resample) return np.asarray(pil_img) @@ -60,7 +69,7 @@ class UnifiedResize(object): # compatible with opencv < version 4.4.0 elif interpolation is None: interpolation = cv2.INTER_LINEAR - self.resize_func = partial(cv2.resize, interpolation=interpolation) + self.resize_func = partial(_cv2_resize, resample=interpolation) elif backend.lower() == "pil": if isinstance(interpolation, str): interpolation = _pil_interp_from_str[interpolation.lower()]