提交 19840cb0 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: to be pythonic

上级 e823f178
...@@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage ...@@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90 from ppcls.data.preprocess.ops.operators import RandomRot90
from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
...@@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment): ...@@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.prob = prob self.prob = prob
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
img = Image.fromarray(img) img = Image.fromarray(img)
...@@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment): ...@@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment):
img = super().__call__(img) img = super().__call__(img)
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.asarray(img) img = np.asarray(img)
processed_data = {
** return img
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
...@@ -34,6 +34,23 @@ from .functional import augmentations ...@@ -34,6 +34,23 @@ from .functional import augmentations
from ppcls.utils import logger from ppcls.utils import logger
def format_data(func):
def warpper(self, data):
if isinstance(data, dict):
img = data["img"]
result = func(self, img)
if not isinstance(result, dict):
result = {"img": result}
return { ** data, ** result}
else:
result = func(self, data)
if isinstance(result, dict):
result = result["img"]
return result
return warpper
class UnifiedResize(object): class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2", return_numpy=True): def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = { _cv2_interp_from_str = {
...@@ -161,8 +178,8 @@ class DecodeImage(object): ...@@ -161,8 +178,8 @@ class DecodeImage(object):
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}." f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
) )
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
assert self.backend == "pil", "invalid input 'img' in DecodeImage" assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray): elif isinstance(img, np.ndarray):
...@@ -189,12 +206,7 @@ class DecodeImage(object): ...@@ -189,12 +206,7 @@ class DecodeImage(object):
if self.channel_first: if self.channel_first:
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ResizeImage(object): class ResizeImage(object):
...@@ -421,8 +433,8 @@ class RandCropImage(object): ...@@ -421,8 +433,8 @@ class RandCropImage(object):
self._resize_func = UnifiedResize( self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend) interpolation=interpolation, backend=backend)
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
size = self.size size = self.size
scale = self.scale scale = self.scale
ratio = self.ratio ratio = self.ratio
...@@ -447,12 +459,7 @@ class RandCropImage(object): ...@@ -447,12 +459,7 @@ class RandCropImage(object):
j = random.randint(0, img_h - h) j = random.randint(0, img_h - h)
img = self._resize_func(img[j:j + h, i:i + w, :], size) img = self._resize_func(img[j:j + h, i:i + w, :], size)
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class RandCropImageV2(object): class RandCropImageV2(object):
...@@ -557,8 +564,8 @@ class NormalizeImage(object): ...@@ -557,8 +564,8 @@ class NormalizeImage(object):
self.mean = np.array(mean).reshape(shape).astype('float32') self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32') self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
from PIL import Image from PIL import Image
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.array(img) img = np.array(img)
...@@ -580,12 +587,7 @@ class NormalizeImage(object): ...@@ -580,12 +587,7 @@ class NormalizeImage(object):
(img, pad_zeros), axis=2)) (img, pad_zeros), axis=2))
img = img.astype(self.output_dtype) img = img.astype(self.output_dtype)
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ToCHWImage(object): class ToCHWImage(object):
...@@ -772,15 +774,9 @@ class RandomRot90(object): ...@@ -772,15 +774,9 @@ class RandomRot90(object):
def __init__(self): def __init__(self):
pass pass
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
orientation = random.choice([0, 1, 2, 3]) orientation = random.choice([0, 1, 2, 3])
if orientation: if orientation:
img = np.rot90(img, orientation) img = np.rot90(img, orientation)
processed_data = { return {"img": img, "random_rot90_orientation": orientation}
**
ori_data,
"img": img,
"random_rot90_orientation": orientation
} if isinstance(ori_data, dict) else img
return processed_data
...@@ -22,6 +22,8 @@ import random ...@@ -22,6 +22,8 @@ import random
import numpy as np import numpy as np
from .operators import format_data
class Pixels(object): class Pixels(object):
def __init__(self, mode="const", mean=[0., 0., 0.]): def __init__(self, mode="const", mean=[0., 0., 0.]):
...@@ -70,11 +72,10 @@ class RandomErasing(object): ...@@ -70,11 +72,10 @@ class RandomErasing(object):
self.attempt = attempt self.attempt = attempt
self.get_pixels = Pixels(mode, mean) self.get_pixels = Pixels(mode, mean)
def __call__(self, ori_data): @format_data
def __call__(self, img):
if random.random() > self.EPSILON: if random.random() > self.EPSILON:
return ori_data return img
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
for _ in range(self.attempt): for _ in range(self.attempt):
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
...@@ -107,16 +108,6 @@ class RandomErasing(object): ...@@ -107,16 +108,6 @@ class RandomErasing(object):
img[0, x1:x1 + h, y1:y1 + w] = pixels[0] img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else: else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0] img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
processed_data = { return img
**
ori_data, return img
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册