提交 19840cb0 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: to be pythonic

上级 e823f178
......@@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
......@@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
super().__init__(*args, **kwargs)
self.prob = prob
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
......@@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment):
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
......@@ -34,6 +34,23 @@ from .functional import augmentations
from ppcls.utils import logger
def format_data(func):
def warpper(self, data):
if isinstance(data, dict):
img = data["img"]
result = func(self, img)
if not isinstance(result, dict):
result = {"img": result}
return { ** data, ** result}
else:
result = func(self, data)
if isinstance(result, dict):
result = result["img"]
return result
return warpper
class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = {
......@@ -161,8 +178,8 @@ class DecodeImage(object):
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
)
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
if isinstance(img, Image.Image):
assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray):
......@@ -189,12 +206,7 @@ class DecodeImage(object):
if self.channel_first:
img = img.transpose((2, 0, 1))
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
class ResizeImage(object):
......@@ -421,8 +433,8 @@ class RandCropImage(object):
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
size = self.size
scale = self.scale
ratio = self.ratio
......@@ -447,12 +459,7 @@ class RandCropImage(object):
j = random.randint(0, img_h - h)
img = self._resize_func(img[j:j + h, i:i + w, :], size)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
class RandCropImageV2(object):
......@@ -557,8 +564,8 @@ class NormalizeImage(object):
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
......@@ -580,12 +587,7 @@ class NormalizeImage(object):
(img, pad_zeros), axis=2))
img = img.astype(self.output_dtype)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
class ToCHWImage(object):
......@@ -772,15 +774,9 @@ class RandomRot90(object):
def __init__(self):
pass
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
orientation = random.choice([0, 1, 2, 3])
if orientation:
img = np.rot90(img, orientation)
processed_data = {
**
ori_data,
"img": img,
"random_rot90_orientation": orientation
} if isinstance(ori_data, dict) else img
return processed_data
return {"img": img, "random_rot90_orientation": orientation}
......@@ -22,6 +22,8 @@ import random
import numpy as np
from .operators import format_data
class Pixels(object):
def __init__(self, mode="const", mean=[0., 0., 0.]):
......@@ -70,11 +72,10 @@ class RandomErasing(object):
self.attempt = attempt
self.get_pixels = Pixels(mode, mean)
def __call__(self, ori_data):
@format_data
def __call__(self, img):
if random.random() > self.EPSILON:
return ori_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
return img
for _ in range(self.attempt):
if isinstance(img, np.ndarray):
......@@ -107,16 +108,6 @@ class RandomErasing(object):
img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
return img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册