提交 3f84dcb4 编写于 作者: F flytocc

add random interpolation for UnifiedResize

上级 c3018ebd
...@@ -38,7 +38,8 @@ class UnifiedResize(object): ...@@ -38,7 +38,8 @@ class UnifiedResize(object):
'bilinear': cv2.INTER_LINEAR, 'bilinear': cv2.INTER_LINEAR,
'area': cv2.INTER_AREA, 'area': cv2.INTER_AREA,
'bicubic': cv2.INTER_CUBIC, 'bicubic': cv2.INTER_CUBIC,
'lanczos': cv2.INTER_LANCZOS4 'lanczos': cv2.INTER_LANCZOS4,
'random': (cv2.INTER_LINEAR, cv2.INTER_CUBIC)
} }
_pil_interp_from_str = { _pil_interp_from_str = {
'nearest': Image.NEAREST, 'nearest': Image.NEAREST,
...@@ -46,10 +47,18 @@ class UnifiedResize(object): ...@@ -46,10 +47,18 @@ class UnifiedResize(object):
'bicubic': Image.BICUBIC, 'bicubic': Image.BICUBIC,
'box': Image.BOX, 'box': Image.BOX,
'lanczos': Image.LANCZOS, 'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING 'hamming': Image.HAMMING,
'random': (Image.BILINEAR, Image.BICUBIC)
} }
def _cv2_resize(src, size, resample):
if isinstance(resample, tuple):
resample = random.choice(resample)
return cv2.resize(src, size, interpolation=resample)
def _pil_resize(src, size, resample): def _pil_resize(src, size, resample):
if isinstance(resample, tuple):
resample = random.choice(resample)
pil_img = Image.fromarray(src) pil_img = Image.fromarray(src)
pil_img = pil_img.resize(size, resample) pil_img = pil_img.resize(size, resample)
return np.asarray(pil_img) return np.asarray(pil_img)
...@@ -60,7 +69,7 @@ class UnifiedResize(object): ...@@ -60,7 +69,7 @@ class UnifiedResize(object):
# compatible with opencv < version 4.4.0 # compatible with opencv < version 4.4.0
elif interpolation is None: elif interpolation is None:
interpolation = cv2.INTER_LINEAR interpolation = cv2.INTER_LINEAR
self.resize_func = partial(cv2.resize, interpolation=interpolation) self.resize_func = partial(_cv2_resize, resample=interpolation)
elif backend.lower() == "pil": elif backend.lower() == "pil":
if isinstance(interpolation, str): if isinstance(interpolation, str):
interpolation = _pil_interp_from_str[interpolation.lower()] interpolation = _pil_interp_from_str[interpolation.lower()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册