提交 64c37000 编写于 作者: G gaotingquan

feat: support pil resize

Support PIL resize with PIL interpolation for training transformers.
Almost all vision transformer models need to use PIL.Image.BICUBIC
as the interpolation method when resizing.
上级 74622af4
...@@ -19,12 +19,14 @@ from __future__ import division ...@@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from functools import partial
import six import six
import math import math
import random import random
import cv2 import cv2
import numpy as np import numpy as np
import importlib import importlib
from PIL import Image
from python.det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize from python.det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize
...@@ -50,6 +52,47 @@ def create_operators(params): ...@@ -50,6 +52,47 @@ def create_operators(params):
return ops return ops
class UnifiedResize(object):
    """Resize callable that hides the difference between the cv2 and PIL backends.

    Args:
        interpolation: Interpolation to use. Either a name (matched
            case-insensitively against the selected backend's table below),
            a backend-native constant that is passed through unchanged, or
            None to leave the choice to the backend's own default.
        backend: "cv2" or "pil" (case-insensitive). Any other value logs a
            warning and falls back to plain ``cv2.resize``.

    Raises:
        KeyError: if ``interpolation`` is a string that the selected
            backend's table does not contain.
    """

    def __init__(self, interpolation=None, backend="cv2"):
        _cv2_interp_from_str = {
            'nearest': cv2.INTER_NEAREST,
            'bilinear': cv2.INTER_LINEAR,
            'area': cv2.INTER_AREA,
            'bicubic': cv2.INTER_CUBIC,
            'lanczos': cv2.INTER_LANCZOS4
        }
        _pil_interp_from_str = {
            'nearest': Image.NEAREST,
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC,
            'box': Image.BOX,
            'lanczos': Image.LANCZOS,
            'hamming': Image.HAMMING
        }

        def _pil_resize(src, size, resample=None):
            # Round-trip through PIL so the resampling filter is PIL's own.
            pil_img = Image.fromarray(src)
            if resample is None:
                # Do not forward an explicit None: let PIL pick its default
                # filter instead of depending on how it validates None.
                pil_img = pil_img.resize(size)
            else:
                pil_img = pil_img.resize(size, resample)
            return np.asarray(pil_img)

        backend_name = backend.lower()
        if backend_name == "cv2":
            if isinstance(interpolation, str):
                interpolation = _cv2_interp_from_str[interpolation.lower()]
            if interpolation is None:
                # No interpolation requested: call cv2.resize with its default.
                self.resize_func = cv2.resize
            else:
                self.resize_func = partial(
                    cv2.resize, interpolation=interpolation)
        elif backend_name == "pil":
            if isinstance(interpolation, str):
                interpolation = _pil_interp_from_str[interpolation.lower()]
            self.resize_func = partial(_pil_resize, resample=interpolation)
        else:
            # NOTE(review): `logger` is not among the imports visible in this
            # hunk - confirm it is defined at module scope in this file.
            logger.warning(
                f"The backend of Resize only support \"cv2\" or \"PIL\". \"{backend}\" is unavailable. Use \"cv2\" instead."
            )
            # The fallback deliberately ignores `interpolation`: it may be a
            # name or constant that only the unknown backend would understand.
            self.resize_func = cv2.resize

    def __call__(self, src, size):
        """Resize ``src`` to ``size`` using the function chosen in ``__init__``."""
        return self.resize_func(src, size)
class OperatorParamError(ValueError): class OperatorParamError(ValueError):
""" OperatorParamError """ OperatorParamError
""" """
...@@ -87,8 +130,11 @@ class DecodeImage(object): ...@@ -87,8 +130,11 @@ class DecodeImage(object):
class ResizeImage(object): class ResizeImage(object):
""" resize image """ """ resize image """
def __init__(self, size=None, resize_short=None, interpolation=-1): def __init__(self,
self.interpolation = interpolation if interpolation >= 0 else None size=None,
resize_short=None,
interpolation=None,
backend="cv2"):
if resize_short is not None and resize_short > 0: if resize_short is not None and resize_short > 0:
self.resize_short = resize_short self.resize_short = resize_short
self.w = None self.w = None
...@@ -101,6 +147,9 @@ class ResizeImage(object): ...@@ -101,6 +147,9 @@ class ResizeImage(object):
raise OperatorParamError("invalid params for ReisizeImage for '\ raise OperatorParamError("invalid params for ReisizeImage for '\
'both 'size' and 'resize_short' are None") 'both 'size' and 'resize_short' are None")
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, img): def __call__(self, img):
img_h, img_w = img.shape[:2] img_h, img_w = img.shape[:2]
if self.resize_short is not None: if self.resize_short is not None:
...@@ -110,10 +159,7 @@ class ResizeImage(object): ...@@ -110,10 +159,7 @@ class ResizeImage(object):
else: else:
w = self.w w = self.w
h = self.h h = self.h
if self.interpolation is None: return self._resize_func(img, (w, h))
return cv2.resize(img, (w, h))
else:
return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object): class CropImage(object):
...@@ -145,9 +191,12 @@ class CropImage(object): ...@@ -145,9 +191,12 @@ class CropImage(object):
class RandCropImage(object): class RandCropImage(object):
""" random crop image """ """ random crop image """
def __init__(self, size, scale=None, ratio=None, interpolation=-1): def __init__(self,
size,
self.interpolation = interpolation if interpolation >= 0 else None scale=None,
ratio=None,
interpolation=None,
backend="cv2"):
if type(size) is int: if type(size) is int:
self.size = (size, size) # (h, w) self.size = (size, size) # (h, w)
else: else:
...@@ -156,6 +205,9 @@ class RandCropImage(object): ...@@ -156,6 +205,9 @@ class RandCropImage(object):
self.scale = [0.08, 1.0] if scale is None else scale self.scale = [0.08, 1.0] if scale is None else scale
self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, img): def __call__(self, img):
size = self.size size = self.size
scale = self.scale scale = self.scale
...@@ -181,10 +233,8 @@ class RandCropImage(object): ...@@ -181,10 +233,8 @@ class RandCropImage(object):
j = random.randint(0, img_h - h) j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :] img = img[j:j + h, i:i + w, :]
if self.interpolation is None:
return cv2.resize(img, size) return self._resize_func(img, size)
else:
return cv2.resize(img, size, interpolation=self.interpolation)
class RandFlipImage(object): class RandFlipImage(object):
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from functools import partial
import six import six
import math import math
import random import random
...@@ -28,6 +29,48 @@ from PIL import Image ...@@ -28,6 +29,48 @@ from PIL import Image
from .autoaugment import ImageNetPolicy from .autoaugment import ImageNetPolicy
from .functional import augmentations from .functional import augmentations
from ppcls.utils import logger
class UnifiedResize(object):
    """Resize callable that hides the difference between the cv2 and PIL backends.

    Args:
        interpolation: Interpolation to use. Either a name (matched
            case-insensitively against the selected backend's table below),
            a backend-native constant that is passed through unchanged, or
            None to leave the choice to the backend's own default.
        backend: "cv2" or "pil" (case-insensitive). Any other value logs a
            warning and falls back to plain ``cv2.resize``.

    Raises:
        KeyError: if ``interpolation`` is a string that the selected
            backend's table does not contain.
    """

    def __init__(self, interpolation=None, backend="cv2"):
        _cv2_interp_from_str = {
            'nearest': cv2.INTER_NEAREST,
            'bilinear': cv2.INTER_LINEAR,
            'area': cv2.INTER_AREA,
            'bicubic': cv2.INTER_CUBIC,
            'lanczos': cv2.INTER_LANCZOS4
        }
        _pil_interp_from_str = {
            'nearest': Image.NEAREST,
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC,
            'box': Image.BOX,
            'lanczos': Image.LANCZOS,
            'hamming': Image.HAMMING
        }

        def _pil_resize(src, size, resample=None):
            # Round-trip through PIL so the resampling filter is PIL's own.
            pil_img = Image.fromarray(src)
            if resample is None:
                # Do not forward an explicit None: let PIL pick its default
                # filter instead of depending on how it validates None.
                pil_img = pil_img.resize(size)
            else:
                pil_img = pil_img.resize(size, resample)
            return np.asarray(pil_img)

        backend_name = backend.lower()
        if backend_name == "cv2":
            if isinstance(interpolation, str):
                interpolation = _cv2_interp_from_str[interpolation.lower()]
            if interpolation is None:
                # No interpolation requested: call cv2.resize with its default.
                self.resize_func = cv2.resize
            else:
                self.resize_func = partial(
                    cv2.resize, interpolation=interpolation)
        elif backend_name == "pil":
            if isinstance(interpolation, str):
                interpolation = _pil_interp_from_str[interpolation.lower()]
            self.resize_func = partial(_pil_resize, resample=interpolation)
        else:
            logger.warning(
                f"The backend of Resize only support \"cv2\" or \"PIL\". \"{backend}\" is unavailable. Use \"cv2\" instead."
            )
            # The fallback deliberately ignores `interpolation`: it may be a
            # name or constant that only the unknown backend would understand.
            self.resize_func = cv2.resize

    def __call__(self, src, size):
        """Resize ``src`` to ``size`` using the function chosen in ``__init__``."""
        return self.resize_func(src, size)
class OperatorParamError(ValueError): class OperatorParamError(ValueError):
...@@ -67,8 +110,11 @@ class DecodeImage(object): ...@@ -67,8 +110,11 @@ class DecodeImage(object):
class ResizeImage(object): class ResizeImage(object):
""" resize image """ """ resize image """
def __init__(self, size=None, resize_short=None, interpolation=-1): def __init__(self,
self.interpolation = interpolation if interpolation >= 0 else None size=None,
resize_short=None,
interpolation=None,
backend="cv2"):
if resize_short is not None and resize_short > 0: if resize_short is not None and resize_short > 0:
self.resize_short = resize_short self.resize_short = resize_short
self.w = None self.w = None
...@@ -81,6 +127,9 @@ class ResizeImage(object): ...@@ -81,6 +127,9 @@ class ResizeImage(object):
raise OperatorParamError("invalid params for ReisizeImage for '\ raise OperatorParamError("invalid params for ReisizeImage for '\
'both 'size' and 'resize_short' are None") 'both 'size' and 'resize_short' are None")
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, img): def __call__(self, img):
img_h, img_w = img.shape[:2] img_h, img_w = img.shape[:2]
if self.resize_short is not None: if self.resize_short is not None:
...@@ -90,10 +139,7 @@ class ResizeImage(object): ...@@ -90,10 +139,7 @@ class ResizeImage(object):
else: else:
w = self.w w = self.w
h = self.h h = self.h
if self.interpolation is None: return self._resize_func(img, (w, h))
return cv2.resize(img, (w, h))
else:
return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object): class CropImage(object):
...@@ -119,9 +165,12 @@ class CropImage(object): ...@@ -119,9 +165,12 @@ class CropImage(object):
class RandCropImage(object): class RandCropImage(object):
""" random crop image """ """ random crop image """
def __init__(self, size, scale=None, ratio=None, interpolation=-1): def __init__(self,
size,
self.interpolation = interpolation if interpolation >= 0 else None scale=None,
ratio=None,
interpolation=None,
backend="cv2"):
if type(size) is int: if type(size) is int:
self.size = (size, size) # (h, w) self.size = (size, size) # (h, w)
else: else:
...@@ -130,6 +179,9 @@ class RandCropImage(object): ...@@ -130,6 +179,9 @@ class RandCropImage(object):
self.scale = [0.08, 1.0] if scale is None else scale self.scale = [0.08, 1.0] if scale is None else scale
self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, img): def __call__(self, img):
size = self.size size = self.size
scale = self.scale scale = self.scale
...@@ -155,10 +207,8 @@ class RandCropImage(object): ...@@ -155,10 +207,8 @@ class RandCropImage(object):
j = random.randint(0, img_h - h) j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :] img = img[j:j + h, i:i + w, :]
if self.interpolation is None:
return cv2.resize(img, size) return self._resize_func(img, size)
else:
return cv2.resize(img, size, interpolation=self.interpolation)
class RandFlipImage(object): class RandFlipImage(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册